diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index a8f10684d5a2..2e517707ff77 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -60,6 +60,14 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
     }
   }
 
+  /**
+   * Remove all listeners and they won't receive any events. This method is thread-safe and can be
+   * called in any thread.
+   */
+  final def removeAllListeners(): Unit = {
+    listenersPlusTimers.clear()
+  }
+
   /**
    * This can be overridden by subclasses if there is any extra cleanup to do when removing a
    * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 851fa2334501..a5d6d6366ede 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
   lazy val v30excludes = v24excludes ++ Seq(
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"),
     ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues")
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
   )
 
   // Exclude rules for 2.4.x
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 5d0feecd2cc2..5a28870f5d3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -672,17 +672,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    */
   private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
     val qe = session.sessionState.executePlan(command)
-    try {
-      val start = System.nanoTime()
-      // call `QueryExecution.toRDD` to trigger the execution of commands.
-      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
-      val end = System.nanoTime()
-      session.listenerManager.onSuccess(name, qe, end - start)
-    } catch {
-      case e: Exception =>
-        session.listenerManager.onFailure(name, qe, e)
-        throw e
-    }
+    // call `QueryExecution.toRDD` to trigger the execution of commands.
+    SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
   }
 
   ///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fa14aa14ee96..0fb3301b3616 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3356,21 +3356,11 @@ class Dataset[T] private[sql](
    * user-registered callback functions.
    */
   private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
-    try {
+    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
       qe.executedPlan.foreach { plan =>
         plan.resetMetrics()
       }
-      val start = System.nanoTime()
-      val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
-        action(qe.executedPlan)
-      }
-      val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, qe, end - start)
-      result
-    } catch {
-      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, qe, e)
-        throw e
+      action(qe.executedPlan)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 439932b0cc3a..dda7cb55f539 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -58,7 +58,8 @@ object SQLExecution {
    */
   def withNewExecutionId[T](
       sparkSession: SparkSession,
-      queryExecution: QueryExecution)(body: => T): T = {
+      queryExecution: QueryExecution,
+      name: Option[String] = None)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     val executionId = SQLExecution.nextExecutionId
@@ -71,14 +72,35 @@ object SQLExecution {
       val callSite = sc.getCallSite()
 
       withSQLConfPropagated(sparkSession) {
-        sc.listenerBus.post(SparkListenerSQLExecutionStart(
-          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+        var ex: Option[Exception] = None
+        val startTime = System.nanoTime()
         try {
+          sc.listenerBus.post(SparkListenerSQLExecutionStart(
+            executionId = executionId,
+            description = callSite.shortForm,
+            details = callSite.longForm,
+            physicalPlanDescription = queryExecution.toString,
+            // `queryExecution.executedPlan` triggers query planning. If it fails, the exception
+            // will be caught and reported in the `SparkListenerSQLExecutionEnd`
+            sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
+            time = System.currentTimeMillis()))
           body
+        } catch {
+          case e: Exception =>
+            ex = Some(e)
+            throw e
         } finally {
-          sc.listenerBus.post(SparkListenerSQLExecutionEnd(
-            executionId, System.currentTimeMillis()))
+          val endTime = System.nanoTime()
+          val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
+          // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
+          // parameter. The `ExecutionListenerManager` only watches SQL executions with name. We
+          // can specify the execution name in more places in the future, so that
+          // `QueryExecutionListener` can track more cases.
+          event.executionName = name
+          event.duration = endTime - startTime
+          event.qe = queryExecution
+          event.executionFailure = ex
+          sc.listenerBus.post(event)
         }
       }
     } finally {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index b58b8c6d45e5..c04a31c428d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.ui
 
+import com.fasterxml.jackson.annotation.JsonIgnore
 import com.fasterxml.jackson.databind.JavaType
 import com.fasterxml.jackson.databind.`type`.TypeFactory
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize
@@ -24,8 +25,7 @@ import com.fasterxml.jackson.databind.util.Converter
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.scheduler._
-import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}
 
 @DeveloperApi
 case class SparkListenerSQLExecutionStart(
@@ -39,7 +39,22 @@ case class SparkListenerSQLExecutionStart(
 
 @DeveloperApi
 case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
-  extends SparkListenerEvent
+  extends SparkListenerEvent {
+
+  // The name of the execution, e.g. `df.collect` will trigger a SQL execution with name "collect".
+  @JsonIgnore private[sql] var executionName: Option[String] = None
+
+  // The following 3 fields are only accessed when `executionName` is defined.
+
+  // The duration of the SQL execution, in nanoseconds.
+  @JsonIgnore private[sql] var duration: Long = 0L
+
+  // The `QueryExecution` instance that represents the SQL execution
+  @JsonIgnore private[sql] var qe: QueryExecution = null
+
+  // The exception object that caused this execution to fail. None if the execution doesn't fail.
+  @JsonIgnore private[sql] var executionFailure: Option[Exception] = None
+}
 
 /**
  * A message used to update SQL metric value for driver-side updates (which doesn't get reflected
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 3a0db7e16c23..60bba5e10703 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -266,8 +266,8 @@ abstract class BaseSessionStateBuilder(
    * This gets cloned from parent if available, otherwise a new instance is created.
    */
   protected def listenerManager: ExecutionListenerManager = {
-    parentState.map(_.listenerManager.clone()).getOrElse(
-      new ExecutionListenerManager(session.sparkContext.conf))
+    parentState.map(_.listenerManager.clone(session)).getOrElse(
+      new ExecutionListenerManager(session, loadExtensions = true))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 2b46233e1a5d..1310fdfa1356 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -17,17 +17,16 @@
 
 package org.apache.spark.sql.util
 
-import java.util.concurrent.locks.ReentrantReadWriteLock
+import scala.collection.JavaConverters._
 
-import scala.collection.mutable.ListBuffer
-import scala.util.control.NonFatal
-
-import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.StaticSQLConf._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ListenerBus, Utils}
 
 /**
  * :: Experimental ::
@@ -75,10 +74,18 @@ trait QueryExecutionListener {
  */
 @Experimental
 @InterfaceStability.Evolving
-class ExecutionListenerManager private extends Logging {
-
-  private[sql] def this(conf: SparkConf) = {
-    this()
+// The `session` is used to indicate which session carries this listener manager, and we only
+// catch SQL executions which are launched by the same session.
+// The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
+// user-specified listeners during construction. We should not do it when cloning this listener
+// manager, as we will copy all listeners to the cloned listener manager.
+class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
+  extends Logging {
+
+  private val listenerBus = new ExecutionListenerBus(session)
+
+  if (loadExtensions) {
+    val conf = session.sparkContext.conf
     conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
       Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register)
     }
   }
@@ -88,82 +95,63 @@ class ExecutionListenerManager private extends Logging {
    * Registers the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def register(listener: QueryExecutionListener): Unit = writeLock {
-    listeners += listener
+  def register(listener: QueryExecutionListener): Unit = {
+    listenerBus.addListener(listener)
   }
 
   /**
    * Unregisters the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def unregister(listener: QueryExecutionListener): Unit = writeLock {
-    listeners -= listener
+  def unregister(listener: QueryExecutionListener): Unit = {
+    listenerBus.removeListener(listener)
   }
 
   /**
    * Removes all the registered [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def clear(): Unit = writeLock {
-    listeners.clear()
+  def clear(): Unit = {
+    listenerBus.removeAllListeners()
  }
 
   /**
   * Get an identical copy of this listener manager.
    */
-  @DeveloperApi
-  override def clone(): ExecutionListenerManager = writeLock {
-    val newListenerManager = new ExecutionListenerManager
-    listeners.foreach(newListenerManager.register)
+  private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
+    val newListenerManager = new ExecutionListenerManager(session, loadExtensions = false)
+    listenerBus.listeners.asScala.foreach(newListenerManager.register)
     newListenerManager
   }
+}
 
-  private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onSuccess(funcName, qe, duration)
-      }
-    }
-  }
-
-  private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onFailure(funcName, qe, exception)
-      }
-    }
-  }
-
-  private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
+private[sql] class ExecutionListenerBus(session: SparkSession)
+  extends SparkListener with ListenerBus[QueryExecutionListener, SparkListenerSQLExecutionEnd] {
 
-  /** A lock to prevent updating the list of listeners while we are traversing through them. */
-  private[this] val lock = new ReentrantReadWriteLock()
+  session.sparkContext.listenerBus.addToSharedQueue(this)
 
-  private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
-    for (listener <- listeners) {
-      try {
-        f(listener)
-      } catch {
-        case NonFatal(e) => logWarning("Error executing query execution listener", e)
-      }
-    }
+  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+    case e: SparkListenerSQLExecutionEnd => postToAll(e)
+    case _ =>
   }
 
-  /** Acquires a read lock on the cache for the duration of `f`. */
-  private def readLock[A](f: => A): A = {
-    val rl = lock.readLock()
-    rl.lock()
-    try f finally {
-      rl.unlock()
+  override protected def doPostEvent(
+      listener: QueryExecutionListener,
+      event: SparkListenerSQLExecutionEnd): Unit = {
+    if (shouldReport(event)) {
+      val funcName = event.executionName.get
+      event.executionFailure match {
+        case Some(ex) =>
+          listener.onFailure(funcName, event.qe, ex)
+        case _ =>
+          listener.onSuccess(funcName, event.qe, event.duration)
+      }
    }
   }
 
-  /** Acquires a write lock on the cache for the duration of `f`. */
-  private def writeLock[A](f: => A): A = {
-    val wl = lock.writeLock()
-    wl.lock()
-    try f finally {
-      wl.unlock()
-    }
+  private def shouldReport(e: SparkListenerSQLExecutionEnd): Boolean = {
+    // Only catch SQL execution with a name, and triggered by the same spark session that this
+    // listener manager belongs.
+    e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
index e1b5eba53f06..6317cd28bcc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -155,6 +155,7 @@ class SessionStateSuite extends SparkFunSuite {
     assert(forkedSession ne activeSession)
     assert(forkedSession.listenerManager ne activeSession.listenerManager)
     runCollectQueryOn(forkedSession)
+    activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(collectorA.commands.length == 1) // forked should callback to A
     assert(collectorA.commands(0) == "collect")
 
@@ -162,12 +163,14 @@ class SessionStateSuite extends SparkFunSuite {
     // => changes to forked do not affect original
     forkedSession.listenerManager.register(collectorB)
     runCollectQueryOn(activeSession)
+    activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(collectorB.commands.isEmpty) // original should not callback to B
     assert(collectorA.commands.length == 2) // original should still callback to A
     assert(collectorA.commands(1) == "collect")
     // <= changes to original do not affect forked
     activeSession.listenerManager.register(collectorC)
     runCollectQueryOn(forkedSession)
+    activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(collectorC.commands.isEmpty) // forked should not callback to C
     assert(collectorA.commands.length == 3) // forked should still callback to A
     assert(collectorB.commands.length == 1) // forked should still callback to B
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 30dca9497ddd..269600dd59cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -356,10 +356,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
         .withColumn("b", udf1($"a", lit(10)))
       df.cache()
       df.write.saveAsTable("t")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
       df.write.insertInto("t")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
       df.write.save(path.getCanonicalPath)
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(numTotalCachedHit == 3, "expected to be cached in save for native")
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
index 08e40e28d3d5..08789e63fa7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution
 
-import org.json4s.jackson.JsonMethods.parse
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
+import org.apache.spark.sql.LocalSparkSession
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.test.TestSparkSession
 import org.apache.spark.util.JsonProtocol
 
-class SQLJsonProtocolSuite extends SparkFunSuite {
+class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {
 
   test("SparkPlanGraph backward compatibility: metadata") {
     val SQLExecutionStartJsonString =
@@ -49,4 +51,29 @@ class SQLJsonProtocolSuite extends SparkFunSuite {
       new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
     assert(reconstructedEvent == expectedEvent)
   }
+
+  test("SparkListenerSQLExecutionEnd backward compatibility") {
+    spark = new TestSparkSession()
+    val qe = spark.sql("select 1").queryExecution
+    val event = SparkListenerSQLExecutionEnd(1, 10)
+    event.duration = 1000
+    event.executionName = Some("test")
+    event.qe = qe
+    event.executionFailure = Some(new RuntimeException("test"))
+    val json = JsonProtocol.sparkEventToJson(event)
+    assert(json == parse(
+      """
+        |{
+        |  "Event" : "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd",
+        |  "executionId" : 1,
+        |  "time" : 10
+        |}
+      """.stripMargin))
+    val readBack = JsonProtocol.sparkEventFromJson(json)
+    event.duration = 0
+    event.executionName = None
+    event.qe = null
+    event.executionFailure = None
+    assert(readBack == event)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index a239e39d9c5a..e8710aeb40bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -48,6 +48,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     df.select("i").collect()
     df.filter($"i" > 0).count()
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
 
     assert(metrics(0)._1 == "collect")
@@ -78,6 +79,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     val e = intercept[SparkException](df.select(errorUdf($"i")).collect())
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 1)
     assert(metrics(0)._1 == "collect")
     assert(metrics(0)._2.analyzed.isInstanceOf[Project])
@@ -103,10 +105,16 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     spark.listenerManager.register(listener)
 
     val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+
     df.collect()
+    // Wait for the first `collect` to be caught by our listener. Otherwise the next `collect` will
+    // reset the plan metrics.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     df.collect()
+
     Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 3)
     assert(metrics(0) === 1)
     assert(metrics(1) === 1)
@@ -154,6 +162,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     // For this simple case, the peakExecutionMemory of a stage should be the data size of the
     // aggregate operator, as we only have one memory consuming operator per stage.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
     assert(metrics(0) == topAggDataSize)
     assert(metrics(1) == bottomAggDataSize)
@@ -177,6 +186,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     withTempPath { path =>
       spark.range(10).write.format("json").save(path.getCanonicalPath)
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 1)
       assert(commands.head._1 == "save")
       assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
@@ -187,6 +197,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     withTable("tab") {
       sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess
       spark.range(10).write.insertInto("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 3)
       assert(commands(2)._1 == "insertInto")
       assert(commands(2)._2.isInstanceOf[InsertIntoTable])
@@ -197,6 +208,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     withTable("tab") {
       spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 5)
       assert(commands(4)._1 == "saveAsTable")
       assert(commands(4)._2.isInstanceOf[CreateTable])
@@ -208,6 +220,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       val e = intercept[AnalysisException] {
        spark.range(10).select($"id", $"id").write.insertInto("tab")
       }
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(exceptions.length == 1)
       assert(exceptions.head._1 == "insertInto")
       assert(exceptions.head._2 == e)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
index 4205e23ae240..da414f4311e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
@@ -20,26 +20,28 @@ package org.apache.spark.sql.util
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark._
+import org.apache.spark.sql.{LocalSparkSession, SparkSession}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.internal.StaticSQLConf._
 
-class ExecutionListenerManagerSuite extends SparkFunSuite {
+class ExecutionListenerManagerSuite extends SparkFunSuite with LocalSparkSession {
 
   import CountingQueryExecutionListener._
 
   test("register query execution listeners using configuration") {
     val conf = new SparkConf(false)
       .set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName()))
+    spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
 
-    val mgr = new ExecutionListenerManager(conf)
+    spark.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
-    mgr.onSuccess(null, null, 42L)
     assert(CALLBACK_COUNT.get() === 1)
 
-    val clone = mgr.clone()
+    val cloned = spark.cloneSession()
+    cloned.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
-
-    clone.onSuccess(null, null, 42L)
     assert(CALLBACK_COUNT.get() === 2)
   }
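
A note on the test changes above: after this patch, QueryExecutionListener callbacks are driven by the SparkListenerSQLExecutionEnd event on the shared listener-bus queue, so they fire asynchronously with respect to the action that triggered them, which is why the tests add the waitUntilEmpty(1000) calls. Below is a minimal, hypothetical sketch (object and variable names are illustrative, not part of this patch) of how the new delivery path is exercised; it only compiles inside the org.apache.spark namespace because SparkContext.listenerBus is private[spark].

package org.apache.spark.sql

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object QueryExecutionListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("listener-sketch").getOrCreate()
    import spark.implicits._

    // Record the name of every action ("collect", "save", ...) reported to this listener.
    val names = ArrayBuffer.empty[String]
    spark.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
        names += funcName
      }
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
        names += s"$funcName (failed)"
      }
    })

    Seq(1, 2, 3).toDF("i").collect()

    // The callback now arrives via a SparkListenerSQLExecutionEnd event posted to the shared
    // listener bus, i.e. asynchronously; drain the bus before inspecting the listener's state.
    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
    assert(names.toList == List("collect"))

    spark.stop()
  }
}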