8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -60,6 +60,14 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
}
}

/**
 * Remove all listeners so that they no longer receive any events. This method is thread-safe
 * and can be called from any thread.
*/
final def removeAllListeners(): Unit = {
listenersPlusTimers.clear()
}

/**
* This can be overridden by subclasses if there is any extra cleanup to do when removing a
* listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.
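A note on why the bare `clear()` above is enough: `ListenerBus` keeps `listenersPlusTimers` in a `CopyOnWriteArrayList`, whose iterators work on an immutable snapshot of the backing array. A minimal sketch of that pattern (toy class, not Spark's actual trait):

```scala
import java.util.concurrent.CopyOnWriteArrayList

// Toy bus illustrating the thread-safety argument: postToAll iterates over a
// snapshot, so a concurrent removeAll() cannot throw ConcurrentModificationException;
// in-flight deliveries simply finish against the old listener list.
class ToyBus[L <: AnyRef] {
  private val listeners = new CopyOnWriteArrayList[L]

  def add(listener: L): Unit = listeners.add(listener)

  def postToAll(deliver: L => Unit): Unit = {
    val it = listeners.iterator() // snapshot taken here
    while (it.hasNext) deliver(it.next())
  }

  def removeAll(): Unit = listeners.clear() // swaps in an empty array under the lock
}
```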
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
lazy val v30excludes = v24excludes ++ Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues")
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
)

// Exclude rules for 2.4.x
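For readers unfamiliar with MiMa: each `ProblemFilters.exclude` waives one binary-compatibility error that this PR introduces deliberately. Roughly, the two new excludes map to these signature changes (paraphrased from the diffs shown later in this PR; exact MiMa internals elided):

```scala
// ExecutionListenerManager.clone: the public override becomes a session-aware,
// sql-private method.
//   before: override def clone(): ExecutionListenerManager
//   after:  private[sql] def clone(session: SparkSession): ExecutionListenerManager
//
// ExecutionListenerManager.this: the SparkConf auxiliary constructor goes away.
//   before: private[sql] def this(conf: SparkConf)
//   after:  private[sql] def this(session: SparkSession, loadExtensions: Boolean)
```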
13 changes: 2 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -672,17 +672,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*/
private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
val qe = session.sessionState.executePlan(command)
try {
val start = System.nanoTime()
// call `QueryExecution.toRDD` to trigger the execution of commands.
SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
val end = System.nanoTime()
session.listenerManager.onSuccess(name, qe, end - start)
} catch {
case e: Exception =>
session.listenerManager.onFailure(name, qe, e)
throw e
}
// call `QueryExecution.toRDD` to trigger the execution of commands.
SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
}

///////////////////////////////////////////////////////////////////////////////////////
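Since `runCommand` now goes through `SQLExecution.withNewExecutionId` with an execution name, a registered `QueryExecutionListener` gets its callbacks from the event bus rather than from a direct method call. A hedged usage sketch (`spark` is an assumed active SparkSession and `df` an assumed DataFrame):

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// After this PR the callback also fires for write commands, not just actions.
spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    println(s"$funcName succeeded in ${durationNs / 1e6} ms") // duration stays in nanos
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    println(s"$funcName failed: $exception")
})

df.write.mode("overwrite").parquet("/tmp/out") // should invoke onSuccess with funcName "save"
// Delivery is now asynchronous via the shared listener bus; Spark's own tests
// drain it before asserting (listenerBus is private[spark], so this only
// compiles inside Spark's test code):
//   spark.sparkContext.listenerBus.waitUntilEmpty(1000)
```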
14 changes: 2 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3356,21 +3356,11 @@ class Dataset[T] private[sql](
* user-registered callback functions.
*/
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
try {
SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
qe.executedPlan.foreach { plan =>
plan.resetMetrics()
}
val start = System.nanoTime()
val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
action(qe.executedPlan)
}
val end = System.nanoTime()
sparkSession.listenerManager.onSuccess(name, qe, end - start)
result
} catch {
case e: Exception =>
sparkSession.listenerManager.onFailure(name, qe, e)
throw e
action(qe.executedPlan)
}
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -58,7 +58,8 @@ object SQLExecution {
*/
def withNewExecutionId[T](
sparkSession: SparkSession,
queryExecution: QueryExecution)(body: => T): T = {
queryExecution: QueryExecution,
name: Option[String] = None)(body: => T): T = {
val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
val executionId = SQLExecution.nextExecutionId
@@ -71,14 +72,35 @@
val callSite = sc.getCallSite()

withSQLConfPropagated(sparkSession) {
sc.listenerBus.post(SparkListenerSQLExecutionStart(
executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
var ex: Option[Exception] = None
val startTime = System.nanoTime()
try {
sc.listenerBus.post(SparkListenerSQLExecutionStart(
executionId = executionId,
description = callSite.shortForm,
details = callSite.longForm,
physicalPlanDescription = queryExecution.toString,
// `queryExecution.executedPlan` triggers query planning. If it fails, the exception
// will be caught and reported in the `SparkListenerSQLExecutionEnd`
sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
time = System.currentTimeMillis()))
body
} catch {
case e: Exception =>
ex = Some(e)
throw e
} finally {
sc.listenerBus.post(SparkListenerSQLExecutionEnd(
executionId, System.currentTimeMillis()))
val endTime = System.nanoTime()
val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
// Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
parameter. The `ExecutionListenerManager` only watches SQL executions with a name. We
// can specify the execution name in more places in the future, so that
// `QueryExecutionListener` can track more cases.
event.executionName = name
event.duration = endTime - startTime

Review comment (Contributor): duration used to be reported in nanos. Now it's millis. I would still report it as nanos if possible.

Reply (Author): ah, good catch!


event.qe = queryExecution
event.executionFailure = ex
sc.listenerBus.post(event)
}
}
} finally {
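Boiled down, the new control flow above replaces caller-side listener callbacks with an outcome-carrying end event: the name, duration, and failure ride on the event, and the bus-side listener decides whom to notify. A self-contained toy sketch of that pattern (toy names, not Spark's API; the real code posts to `sc.listenerBus`):

```scala
// Toy "carry the outcome on the end event" pattern, mirroring the try/catch/finally above.
case class ExecutionEnd(executionId: Long, time: Long) {
  var executionName: Option[String] = None
  var durationNs: Long = 0L
  var executionFailure: Option[Exception] = None
}

def withTrackedExecution[T](executionId: Long, name: Option[String])
                           (post: ExecutionEnd => Unit)(body: => T): T = {
  var failure: Option[Exception] = None
  val start = System.nanoTime()
  try {
    body
  } catch {
    case e: Exception =>
      failure = Some(e) // remember the failure for the end event...
      throw e           // ...but still propagate it to the caller
  } finally {
    val event = ExecutionEnd(executionId, System.currentTimeMillis())
    event.executionName = name
    event.durationNs = System.nanoTime() - start
    event.executionFailure = failure
    post(event) // listeners matching the session/name react; others ignore it
  }
}
```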
sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -17,15 +17,15 @@

package org.apache.spark.sql.execution.ui

import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.databind.JavaType
import com.fasterxml.jackson.databind.`type`.TypeFactory
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.fasterxml.jackson.databind.util.Converter

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}

@DeveloperApi
case class SparkListenerSQLExecutionStart(
Expand All @@ -39,7 +39,22 @@ case class SparkListenerSQLExecutionStart(

@DeveloperApi
case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
extends SparkListenerEvent
extends SparkListenerEvent {

// The name of the execution, e.g. `df.collect` will trigger a SQL execution with name "collect".
@JsonIgnore private[sql] var executionName: Option[String] = None

Review comment (Author): For backward compatibility, I make these new fields `var`.

Reply (Contributor): Why do we want to be backwards compatible here? SHS?

Reply (Author): It's a developer API, which is public. The backward compatibility is not as strong as for end-user public APIs, but we should still keep them unchanged if it's not too hard.

Reply (Author): That said, a developer can write a Spark listener and catch this event.


// The following 3 fields are only accessed when `executionName` is defined.

// The duration of the SQL execution, in nanoseconds.
@JsonIgnore private[sql] var duration: Long = 0L
Review comment (@brkyvz, Oct 12, 2018): did you verify that the JsonIgnore annotation actually works? For some reason, I actually needed to annotate the class as

    @JsonIgnoreProperties(Array("a", "b", "c"))
    class SomeClass {
      @JsonProperty("a") val a: ...
      @JsonProperty("b") val b: ...
    }

the reason being Json4s understands that API better. I believe we use Json4s for all of these events.

Reply (@cloud-fan, Oct 12, 2018): There is a test to verify it: https://github.com/apache/spark/pull/22674/files#diff-6fa1d00d1cb20554dda238f2a3bc3ecbR55

I also used @JsonIgnoreProperties before, when I put these fields in the case class constructor. It seems we don't need @JsonIgnoreProperties when they are private vars.


// The `QueryExecution` instance that represents the SQL execution
@JsonIgnore private[sql] var qe: QueryExecution = null

// The exception object that caused this execution to fail. None if the execution doesn't fail.
@JsonIgnore private[sql] var executionFailure: Option[Exception] = None
}
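Regarding the `@JsonIgnore` discussion above: a sketch of the kind of check the linked test performs. The toy class and mapper setup here are assumptions; Spark's `JsonProtocol` serializes events it does not handle explicitly with a Jackson `ObjectMapper` plus `DefaultScalaModule`, so the sketch mirrors that.

```scala
import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

// Toy event: the annotated var should stay out of the JSON event log.
case class ToyExecutionEnd(executionId: Long, time: Long) {
  @JsonIgnore var executionName: Option[String] = None
}

object JsonIgnoreCheck extends App {
  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
  val event = ToyExecutionEnd(1L, 42L)
  event.executionName = Some("collect")
  val json = mapper.writeValueAsString(event)
  assert(!json.contains("executionName"), s"field leaked into: $json")
  println(json) // expected roughly: {"executionId":1,"time":42}
}
```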

/**
* A message used to update SQL metric value for driver-side updates (which doesn't get reflected
sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -266,8 +266,8 @@ abstract class BaseSessionStateBuilder(
* This gets cloned from parent if available, otherwise a new instance is created.
*/
protected def listenerManager: ExecutionListenerManager = {
parentState.map(_.listenerManager.clone()).getOrElse(
new ExecutionListenerManager(session.sparkContext.conf))
parentState.map(_.listenerManager.clone(session)).getOrElse(
new ExecutionListenerManager(session, loadExtensions = true))
}

/**
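Why `loadExtensions = false` on the clone path: the parent manager may already have instantiated listeners from `spark.sql.queryExecutionListeners`, and `clone` copies every registered listener explicitly, so loading the extensions a second time would register duplicates. A toy model of that invariant (names assumed, not Spark's classes):

```scala
import scala.collection.mutable.ListBuffer

// Toy manager: listeners come either from conf (extensions) or from an explicit copy.
class ToyListenerManager(loadExtensions: Boolean,
                         confListeners: Seq[String] = Seq("com.example.AuditListener")) {
  val listeners = ListBuffer.empty[String]
  if (loadExtensions) listeners ++= confListeners // e.g. spark.sql.queryExecutionListeners

  def cloneManager(): ToyListenerManager = {
    val cloned = new ToyListenerManager(loadExtensions = false) // don't reload from conf
    listeners.foreach(cloned.listeners += _)                    // copy everything instead
    cloned
  }
}
// new ToyListenerManager(true).cloneManager().listeners
//   == ListBuffer("com.example.AuditListener")   // registered once, not twice
```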
sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -17,17 +17,16 @@

package org.apache.spark.sql.util

import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.JavaConverters._

import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal

import org.apache.spark.SparkConf
import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.util.Utils
import org.apache.spark.util.{ListenerBus, Utils}

/**
* :: Experimental ::
@@ -75,10 +74,18 @@ trait QueryExecutionListener {
*/
@Experimental
@InterfaceStability.Evolving
class ExecutionListenerManager private extends Logging {

private[sql] def this(conf: SparkConf) = {
this()
// The `session` is used to indicate which session carries this listener manager, and we only
// catch SQL executions which are launched by the same session.
// The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
// user-specified listeners during construction. We should not do it when cloning this listener
// manager, as we will copy all listeners to the cloned listener manager.
class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
  extends Logging {

Review comment (Contributor): Why is this not a class doc?

Reply (Author): The constructor is private, so we should not make it visible in the class doc.

Review comment (Contributor): nit: we shall add param comments.

private val listenerBus = new ExecutionListenerBus(session)

if (loadExtensions) {
val conf = session.sparkContext.conf
conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register)
}
@@ -88,82 +95,63 @@ class ExecutionListenerManager private extends Logging {
* Registers the specified [[QueryExecutionListener]].
*/
@DeveloperApi
def register(listener: QueryExecutionListener): Unit = writeLock {
listeners += listener
def register(listener: QueryExecutionListener): Unit = {
listenerBus.addListener(listener)
}

/**
* Unregisters the specified [[QueryExecutionListener]].
*/
@DeveloperApi
def unregister(listener: QueryExecutionListener): Unit = writeLock {
listeners -= listener
def unregister(listener: QueryExecutionListener): Unit = {
listenerBus.removeListener(listener)
}

/**
* Removes all the registered [[QueryExecutionListener]].
*/
@DeveloperApi
def clear(): Unit = writeLock {
listeners.clear()
def clear(): Unit = {
listenerBus.removeAllListeners()
}

/**
* Get an identical copy of this listener manager.
*/
@DeveloperApi
override def clone(): ExecutionListenerManager = writeLock {
val newListenerManager = new ExecutionListenerManager
listeners.foreach(newListenerManager.register)
private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
val newListenerManager = new ExecutionListenerManager(session, loadExtensions = false)
listenerBus.listeners.asScala.foreach(newListenerManager.register)
newListenerManager
}
}

private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
readLock {
withErrorHandling { listener =>
listener.onSuccess(funcName, qe, duration)
}
}
}

private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
readLock {
withErrorHandling { listener =>
listener.onFailure(funcName, qe, exception)
}
}
}

private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
private[sql] class ExecutionListenerBus(session: SparkSession)
extends SparkListener with ListenerBus[QueryExecutionListener, SparkListenerSQLExecutionEnd] {

/** A lock to prevent updating the list of listeners while we are traversing through them. */
private[this] val lock = new ReentrantReadWriteLock()
session.sparkContext.listenerBus.addToSharedQueue(this)

private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
for (listener <- listeners) {
try {
f(listener)
} catch {
case NonFatal(e) => logWarning("Error executing query execution listener", e)
}
}
override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
case e: SparkListenerSQLExecutionEnd => postToAll(e)
case _ =>
}

/** Acquires a read lock on the cache for the duration of `f`. */
private def readLock[A](f: => A): A = {
val rl = lock.readLock()
rl.lock()
try f finally {
rl.unlock()
override protected def doPostEvent(
listener: QueryExecutionListener,
event: SparkListenerSQLExecutionEnd): Unit = {
if (shouldReport(event)) {
val funcName = event.executionName.get
event.executionFailure match {
case Some(ex) =>
listener.onFailure(funcName, event.qe, ex)
case _ =>
listener.onSuccess(funcName, event.qe, event.duration)
}
}
}

/** Acquires a write lock on the cache for the duration of `f`. */
private def writeLock[A](f: => A): A = {
val wl = lock.writeLock()
wl.lock()
try f finally {
wl.unlock()
}
private def shouldReport(e: SparkListenerSQLExecutionEnd): Boolean = {
    // Only catch SQL executions with a name, and which are triggered by the same Spark session
    // that this listener manager belongs to.
    e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
  }
}

Review comment (Contributor): So this is what bugs me. You are adding separation between the SparkSession and its listeners, only to undo that here. It seems like a bit of a hassle to go through, because you basically need async execution.

Reply (Author): yea. Assume we have many Spark sessions running queries at the same time. Each session sends query execution events to the central event bus, and sets up a listener to watch its own query execution events, asynchronously. To make it work, the most straightforward way is to carry the session identifier in the events, and have the listener only watch events with the expected session identifier. Maybe a better way is to introduce sessions in Spark core, so the listener framework can dispatch events w.r.t. sessions automatically. But that's a lot of work.

Reply (Contributor): we had the same problem in the StreamingQueryListener. You can check how we solved it in StreamExecution. Since each SparkSession will have its own ExecutionListenerManager, you may be able to have only the proper ExecutionListenerManager deal with its own messages.

Reply (Author): @brkyvz thanks for the information! It seems the StreamingQueryListener framework picks the same idea, but the implementation is better. I'll update my PR accordingly.
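To make the resulting behavior concrete, a hedged usage sketch of the session scoping that `shouldReport` implements (`spark` is an assumed active SparkSession and `myListener` an assumed `QueryExecutionListener`; callbacks arrive asynchronously):

```scala
// Two sessions share one SparkContext (one shared listener bus), but each
// ExecutionListenerBus only reports events whose qe.sparkSession is its own session.
val spark2 = spark.newSession()
spark2.listenerManager.register(myListener)

spark.range(10).collect()  // qe.sparkSession eq spark  -> myListener does not fire
spark2.range(10).collect() // qe.sparkSession eq spark2 -> myListener.onSuccess("collect", ...)
```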
sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -155,19 +155,22 @@ class SessionStateSuite extends SparkFunSuite {
assert(forkedSession ne activeSession)
assert(forkedSession.listenerManager ne activeSession.listenerManager)
runCollectQueryOn(forkedSession)
activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(collectorA.commands.length == 1) // forked should callback to A
assert(collectorA.commands(0) == "collect")

// independence
// => changes to forked do not affect original
forkedSession.listenerManager.register(collectorB)
runCollectQueryOn(activeSession)
activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(collectorB.commands.isEmpty) // original should not callback to B
assert(collectorA.commands.length == 2) // original should still callback to A
assert(collectorA.commands(1) == "collect")
// <= changes to original do not affect forked
activeSession.listenerManager.register(collectorC)
runCollectQueryOn(forkedSession)
activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(collectorC.commands.isEmpty) // forked should not callback to C
assert(collectorA.commands.length == 3) // forked should still callback to A
assert(collectorB.commands.length == 1) // forked should still callback to B
3 changes: 3 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -356,10 +356,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
.withColumn("b", udf1($"a", lit(10)))
df.cache()
df.write.saveAsTable("t")
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
df.write.insertInto("t")
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
df.write.save(path.getCanonicalPath)
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(numTotalCachedHit == 3, "expected to be cached in save for native")
}
}