@@ -0,0 +1,231 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import java.util.UUID

import org.json4s.{JObject, JString}
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.annotation.Evolving
import org.apache.spark.scheduler.SparkListenerEvent

/**
 * Interface for listening to events related to [[StreamingQuery StreamingQueries]].
 * @note
 *   The methods are not thread-safe as they may be called from different threads.
 *
 * @since 3.5.0
 */
@Evolving
abstract class StreamingQueryListener extends Serializable {

  import StreamingQueryListener._

  /**
   * Called when a query is started.
   * @note
   *   This is called synchronously with
   *   [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]], that is,
   *   `onQueryStarted` will be called on all listeners before `DataStreamWriter.start()` returns
   *   the corresponding [[StreamingQuery]]. Don't block in this method, as it will block your
   *   query.
   * @since 3.5.0
   */
  def onQueryStarted(event: QueryStartedEvent): Unit

  /**
   * Called when there is some status update (ingestion rate updated, etc.)
   *
   * @note
   *   This method is asynchronous. The status in [[StreamingQuery]] will always be the latest no
   *   matter when this method is called. Therefore, the status of [[StreamingQuery]] may have
   *   changed before/when you process the event. For example, you may find that
   *   [[StreamingQuery]] has terminated while you are processing `QueryProgressEvent`.
   * @since 3.5.0
   */
  def onQueryProgress(event: QueryProgressEvent): Unit

  /**
   * Called when the query is idle and waiting for new data to process.
   * @since 3.5.0
   */
  def onQueryIdle(event: QueryIdleEvent): Unit = {}

  /**
   * Called when a query is stopped, with or without error.
   * @since 3.5.0
   */
  def onQueryTerminated(event: QueryTerminatedEvent): Unit
}
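
// A minimal sketch of a concrete listener (illustrative, not part of this file): it logs each
// life cycle event. `LoggingListener` is a hypothetical name; the event classes are defined in
// the companion object below, and `processedRowsPerSecond` is a field of StreamingQueryProgress.
class LoggingListener extends StreamingQueryListener {
  import StreamingQueryListener._

  override def onQueryStarted(event: QueryStartedEvent): Unit = {
    // Called synchronously with DataStreamWriter.start(), so keep this cheap.
    println(s"Query started: id=${event.id} name=${event.name}")
  }

  override def onQueryProgress(event: QueryProgressEvent): Unit = {
    println(s"Processed rows/sec: ${event.progress.processedRowsPerSecond}")
  }

  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {
    // `exception` is None when the query stopped cleanly.
    event.exception.foreach(e => println(s"Query failed: $e"))
  }
}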

/**
 * Py4J can only call back through pure interfaces, so this proxy trait is required for the
 * Python listener implementation.
 */
private[spark] trait PythonStreamingQueryListener {
  import StreamingQueryListener._

  def onQueryStarted(event: QueryStartedEvent): Unit

  def onQueryProgress(event: QueryProgressEvent): Unit

  def onQueryIdle(event: QueryIdleEvent): Unit

  def onQueryTerminated(event: QueryTerminatedEvent): Unit
}

private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener)
    extends StreamingQueryListener {
  import StreamingQueryListener._

  def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event)

  def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event)

  override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event)

  def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event)
}

/**
 * Companion object of [[StreamingQueryListener]] that defines the listener events.
 * @since 3.5.0
 */
@Evolving
object StreamingQueryListener extends Serializable {

  /**
   * Base type of [[StreamingQueryListener]] events
   * @since 3.5.0
   */
  @Evolving
  trait Event extends SparkListenerEvent

  /**
   * Event representing the start of a query
   * @param id
   *   A unique query id that persists across restarts. See `StreamingQuery.id()`.
   * @param runId
   *   A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
   * @param name
   *   User-specified name of the query, null if not specified.
   * @param timestamp
   *   The timestamp when the query was started.
   * @since 3.5.0
   */
  @Evolving
  class QueryStartedEvent private[sql] (
      val id: UUID,
      val runId: UUID,
      val name: String,
      val timestamp: String)
      extends Event
      with Serializable {

    def json: String = compact(render(jsonValue))

    private def jsonValue: JValue = {
      ("id" -> JString(id.toString)) ~
        ("runId" -> JString(runId.toString)) ~
        ("name" -> JString(name)) ~
        ("timestamp" -> JString(timestamp))
    }
  }
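
  // Shape of `json` for a started event, for illustration (field values hypothetical):
  //   {"id":"7b9c...","runId":"01e4...","name":"myQuery","timestamp":"2023-07-01T00:00:00.000Z"}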

  /**
   * Event representing any progress updates in a query.
   * @param progress
   *   The query progress updates.
   * @since 3.5.0
   */
  @Evolving
  class QueryProgressEvent private[sql] (val progress: StreamingQueryProgress)
      extends Event
      with Serializable {

    def json: String = compact(render(jsonValue))

    private def jsonValue: JValue = JObject("progress" -> progress.jsonValue)
  }

  /**
   * Event representing that a query is idle and waiting for new data to process.
   *
   * @param id
   *   A unique query id that persists across restarts. See `StreamingQuery.id()`.
   * @param runId
   *   A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
   * @param timestamp
   *   The timestamp of the latest trigger that found no new data to process.
   * @since 3.5.0
   */
  @Evolving
  class QueryIdleEvent private[sql] (val id: UUID, val runId: UUID, val timestamp: String)
      extends Event
      with Serializable {

    def json: String = compact(render(jsonValue))

    private def jsonValue: JValue = {
      ("id" -> JString(id.toString)) ~
        ("runId" -> JString(runId.toString)) ~
        ("timestamp" -> JString(timestamp))
    }
  }

  /**
   * Event representing the termination of a query.
   *
   * @param id
   *   A unique query id that persists across restarts. See `StreamingQuery.id()`.
   * @param runId
   *   A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
   * @param exception
   *   The exception message of the query if the query was terminated with an exception.
   *   Otherwise, it will be `None`.
   * @param errorClassOnException
   *   The error class from the exception if the query was terminated with an exception which is
   *   a part of the error class framework. If the query was terminated without an exception, or
   *   the exception is not a part of the error class framework, it will be `None`.
   * @since 3.5.0
   */
  @Evolving
  class QueryTerminatedEvent private[sql] (
      val id: UUID,
      val runId: UUID,
      val exception: Option[String],
      val errorClassOnException: Option[String])
      extends Event
      with Serializable {

    // Compatibility with versions prior to 3.5.0.
    def this(id: UUID, runId: UUID, exception: Option[String]) = {
      this(id, runId, exception, None)
    }

    def json: String = compact(render(jsonValue))

    private def jsonValue: JValue = {
      ("id" -> JString(id.toString)) ~
        ("runId" -> JString(runId.toString)) ~
        ("exception" -> JString(exception.orNull)) ~
        ("errorClassOnException" -> JString(errorClassOnException.orNull))
    }
  }
}
@@ -18,15 +18,20 @@
package org.apache.spark.sql.streaming

import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

import scala.collection.JavaConverters._

import com.google.protobuf.ByteString

import org.apache.spark.annotation.Evolving
import org.apache.spark.connect.proto.Command
import org.apache.spark.connect.proto.StreamingQueryManagerCommand
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.common.{InvalidPlanInput, StreamingListenerPacket}
import org.apache.spark.util.Utils

/**
 * A class to manage all the [[StreamingQuery]]s active in a `SparkSession`.
@@ -36,6 +41,15 @@ import org.apache.spark.sql.SparkSession
@Evolving
class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {

  // Mapping from id to StreamingQueryListener. There is a corresponding mapping from id to
  // listener on the server side. removeListener() uses this map to find the id of a previously
  // added listener and passes it to the server so the matching server-side listener can be
  // located. Mapping by id (rather than by hash) guarantees there are no hash collisions and
  // correctly handles adding and removing the same listener instance multiple times.
  private lazy val listenerCache: ConcurrentMap[String, StreamingQueryListener] =
    new ConcurrentHashMap()
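
  // Illustrative sketch (not part of this file): addListener() below generates a fresh UUID per
  // call, so registering the same instance twice yields two cache entries and two server-side
  // registrations; each removeListener() call then detaches one of them.
  //   val l = new LoggingListener     // hypothetical listener implementation
  //   spark.streams.addListener(l)    // cached under id1
  //   spark.streams.addListener(l)    // cached under id2
  //   spark.streams.removeListener(l) // removes one registration; the other stays attached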

  /**
   * Returns a list of active queries associated with this SQLContext
   *
@@ -126,6 +140,56 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {
    executeManagerCmd(_.setResetTerminated(true))
  }

  /**
   * Register a [[StreamingQueryListener]] to receive up-calls for life cycle events of
   * [[StreamingQuery]].
   *
   * @since 3.5.0
   */
  def addListener(listener: StreamingQueryListener): Unit = {
    // TODO: [SPARK-44400] Improve the Listener to provide users a way to access the Spark session
    // and perform arbitrary actions inside the Listener. Right now users can use
    // `val spark = SparkSession.builder.getOrCreate()` to create a Spark session inside the
    // Listener, but this is a legacy session instead of a connect remote session.
    val id = UUID.randomUUID.toString
    cacheListenerById(id, listener)
    executeManagerCmd(
      _.getAddListenerBuilder
        .setListenerPayload(ByteString.copyFrom(Utils
          .serialize(StreamingListenerPacket(id, listener)))))
  }

  /**
   * Deregister a [[StreamingQueryListener]].
   *
   * @since 3.5.0
   */
  def removeListener(listener: StreamingQueryListener): Unit = {
    val id = getIdByListener(listener)
    executeManagerCmd(
      _.getRemoveListenerBuilder
        .setListenerPayload(ByteString.copyFrom(Utils
          .serialize(StreamingListenerPacket(id, listener)))))
    removeCachedListener(id)
  }

  /**
   * List all [[StreamingQueryListener]]s attached to this [[StreamingQueryManager]].
   *
   * @since 3.5.0
   */
  def listListeners(): Array[StreamingQueryListener] = {
    executeManagerCmd(_.setListListeners(true)).getListListeners.getListenersList.asScala.map {
      listener =>
        Utils
          .deserialize[StreamingListenerPacket](
            listener.getListenerPayload.toByteArray,
            Utils.getContextOrSparkClassLoader)
          .listener
          .asInstanceOf[StreamingQueryListener]
    }.toArray
  }
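
  // A usage sketch, assuming a Spark Connect session `spark` and the hypothetical
  // LoggingListener sketched in the listener file above:
  //   val listener = new LoggingListener
  //   spark.streams.addListener(listener)
  //   spark.streams.listListeners().foreach(l => println(l.getClass.getName))
  //   spark.streams.removeListener(listener)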

  private def executeManagerCmd(
      setCmdFn: StreamingQueryManagerCommand.Builder => Unit // Sets the command field, like stop().
  ): StreamingQueryManagerCommandResult = {
@@ -145,4 +209,17 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {

    resp.getStreamingQueryManagerCommandResult
  }

  private def cacheListenerById(id: String, listener: StreamingQueryListener): Unit = {
    listenerCache.putIfAbsent(id, listener)
  }

  private def getIdByListener(listener: StreamingQueryListener): String = {
    // A linear scan is acceptable here since the number of registered listeners is small.
    listenerCache.asScala
      .collectFirst { case (id, cached) if listener.equals(cached) => id }
      .getOrElse(throw InvalidPlanInput(s"No id with listener $listener is found."))
  }

  private def removeCachedListener(id: String): StreamingQueryListener = {
    listenerCache.remove(id)
  }
}