diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 9dd927084298..bb2f5a8ac809 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -125,6 +125,11 @@
net.sf.jpam
jpam
+
+ org.mockito
+ mockito-core
+ test
+
target/scala-${scala.binary.version}/classes
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 9517a599be63..835cbc3c867b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -36,8 +36,9 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, S
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
-import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab
+import org.apache.spark.sql.hive.thriftserver.ui.{HiveThriftServer2AppStatusStore, LiveExecutionData, LiveSessionData, ThriftServerTab}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.util.{ShutdownHookManager, Utils}
/**
@@ -62,10 +63,13 @@ object HiveThriftServer2 extends Logging {
server.init(executionHive.conf)
server.start()
- listener = new HiveThriftServer2Listener(server, sqlContext.conf)
+ val kvstore = sqlContext.sparkContext.statusStore.store.asInstanceOf[ElementTrackingStore]
+ listener = new HiveThriftServer2Listener(kvstore, server, sqlContext.conf)
sqlContext.sparkContext.addSparkListener(listener)
uiTab = if (sqlContext.sparkContext.getConf.get(UI_ENABLED)) {
- Some(new ThriftServerTab(sqlContext.sparkContext))
+ Some(new ThriftServerTab(
+ new HiveThriftServer2AppStatusStore(kvstore, Some(listener)),
+ sqlContext.sparkContext))
} else {
None
}
@@ -101,10 +105,16 @@ object HiveThriftServer2 extends Logging {
server.init(executionHive.conf)
server.start()
logInfo("HiveThriftServer2 started")
- listener = new HiveThriftServer2Listener(server, SparkSQLEnv.sqlContext.conf)
+ val kvstore = SparkSQLEnv.sparkContext.statusStore.store
+ .asInstanceOf[ElementTrackingStore]
+ listener = new HiveThriftServer2Listener(
+ kvstore,
+ server,
+ SparkSQLEnv.sqlContext.conf)
SparkSQLEnv.sparkContext.addSparkListener(listener)
uiTab = if (SparkSQLEnv.sparkContext.getConf.get(UI_ENABLED)) {
- Some(new ThriftServerTab(SparkSQLEnv.sparkContext))
+ Some(new ThriftServerTab(new HiveThriftServer2AppStatusStore(kvstore, Some(listener)),
+ SparkSQLEnv.sparkContext))
} else {
None
}
@@ -125,16 +135,20 @@ object HiveThriftServer2 extends Logging {
val sessionId: String,
val startTimestamp: Long,
val ip: String,
- val userName: String) {
+ val userName: String) extends LiveEntity {
var finishTimestamp: Long = 0L
var totalExecution: Int = 0
- def totalTime: Long = {
- if (finishTimestamp == 0L) {
- System.currentTimeMillis - startTimestamp
- } else {
- finishTimestamp - startTimestamp
- }
+
+ override protected def doUpdate(): Any = {
+ new LiveSessionData(
+ sessionId,
+ startTimestamp,
+ ip,
+ userName,
+ finishTimestamp,
+ totalExecution)
}
+
}
private[thriftserver] object ExecutionState extends Enumeration {
@@ -143,10 +157,11 @@ object HiveThriftServer2 extends Logging {
}
private[thriftserver] class ExecutionInfo(
+ val execId: String,
val statement: String,
val sessionId: String,
val startTimestamp: Long,
- val userName: String) {
+ val userName: String) extends LiveEntity {
var finishTimestamp: Long = 0L
var closeTimestamp: Long = 0L
var executePlan: String = ""
@@ -154,12 +169,21 @@ object HiveThriftServer2 extends Logging {
var state: ExecutionState.Value = ExecutionState.STARTED
val jobId: ArrayBuffer[String] = ArrayBuffer[String]()
var groupId: String = ""
- def totalTime(endTime: Long): Long = {
- if (endTime == 0L) {
- System.currentTimeMillis - startTimestamp
- } else {
- endTime - startTimestamp
- }
+
+ override protected def doUpdate(): Any = {
+ new LiveExecutionData(
+ execId,
+ statement,
+ sessionId,
+ startTimestamp,
+ userName,
+ finishTimestamp,
+ closeTimestamp,
+ executePlan,
+ detail,
+ state,
+ jobId,
+ groupId)
}
}
@@ -168,6 +192,7 @@ object HiveThriftServer2 extends Logging {
* An inner sparkListener called in sc.stop to clean up the HiveThriftServer2
*/
private[thriftserver] class HiveThriftServer2Listener(
+ val kvstore: ElementTrackingStore,
val server: HiveServer2,
val conf: SQLConf) extends SparkListener {
@@ -179,35 +204,14 @@ object HiveThriftServer2 extends Logging {
private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT)
private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT)
- def getOnlineSessionNum: Int = synchronized {
- sessionList.count(_._2.finishTimestamp == 0)
+ kvstore.addTrigger(classOf[LiveSessionData], retainedSessions) { count =>
+ cleanupSession(count)
}
- def isExecutionActive(execInfo: ExecutionInfo): Boolean = {
- !(execInfo.state == ExecutionState.FAILED ||
- execInfo.state == ExecutionState.CANCELED ||
- execInfo.state == ExecutionState.CLOSED)
+ kvstore.addTrigger(classOf[LiveExecutionData], retainedStatements) { count =>
+ cleanupExecutions(count)
}
- /**
- * When an error or a cancellation occurs, we set the finishTimestamp of the statement.
- * Therefore, when we count the number of running statements, we need to exclude errors and
- * cancellations and count all statements that have not been closed so far.
- */
- def getTotalRunning: Int = synchronized {
- executionList.count {
- case (_, v) => isExecutionActive(v)
- }
- }
-
- def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq }
-
- def getSession(sessionId: String): Option[SessionInfo] = synchronized {
- sessionList.get(sessionId)
- }
-
- def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq }
-
override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized {
for {
props <- Option(jobStart.properties)
@@ -216,6 +220,7 @@ object HiveThriftServer2 extends Logging {
} {
info.jobId += jobStart.jobId.toString
info.groupId = groupId
+ updateLiveStore(info)
}
}
@@ -223,13 +228,14 @@ object HiveThriftServer2 extends Logging {
synchronized {
val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName)
sessionList.put(sessionId, info)
- trimSessionIfNecessary()
+ updateLiveStore(info)
}
}
def onSessionClosed(sessionId: String): Unit = synchronized {
sessionList(sessionId).finishTimestamp = System.currentTimeMillis
- trimSessionIfNecessary()
+ updateLiveStore(sessionList(sessionId))
+ sessionList.remove(sessionId)
}
def onStatementStart(
@@ -238,60 +244,86 @@ object HiveThriftServer2 extends Logging {
statement: String,
groupId: String,
userName: String = "UNKNOWN"): Unit = synchronized {
- val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName)
+ val info = new ExecutionInfo(id, statement, sessionId, System.currentTimeMillis, userName)
info.state = ExecutionState.STARTED
executionList.put(id, info)
- trimExecutionIfNecessary()
sessionList(sessionId).totalExecution += 1
executionList(id).groupId = groupId
+ updateLiveStore(sessionList(sessionId))
+ updateLiveStore(executionList(id))
}
def onStatementParsed(id: String, executionPlan: String): Unit = synchronized {
executionList(id).executePlan = executionPlan
executionList(id).state = ExecutionState.COMPILED
+ updateLiveStore(executionList(id))
}
def onStatementCanceled(id: String): Unit = synchronized {
executionList(id).finishTimestamp = System.currentTimeMillis
executionList(id).state = ExecutionState.CANCELED
- trimExecutionIfNecessary()
+ updateLiveStore(executionList(id))
}
def onStatementError(id: String, errorMsg: String, errorTrace: String): Unit = synchronized {
executionList(id).finishTimestamp = System.currentTimeMillis
executionList(id).detail = errorMsg
executionList(id).state = ExecutionState.FAILED
- trimExecutionIfNecessary()
+ updateLiveStore(executionList(id))
}
def onStatementFinish(id: String): Unit = synchronized {
executionList(id).finishTimestamp = System.currentTimeMillis
executionList(id).state = ExecutionState.FINISHED
- trimExecutionIfNecessary()
+ updateLiveStore(executionList(id))
}
def onOperationClosed(id: String): Unit = synchronized {
executionList(id).closeTimestamp = System.currentTimeMillis
executionList(id).state = ExecutionState.CLOSED
+ updateLiveStore(executionList(id))
+ executionList.remove(id)
+ }
+
+ private def cleanupExecutions(count: Long): Unit = {
+ val countToDelete = calculateNumberToRemove(count, retainedStatements)
+ if (countToDelete <= 0L) {
+ return
+ }
+ val view = kvstore.view(classOf[LiveExecutionData]).index("execId").first(0L)
+ val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt) { j =>
+ j.finishTimestamp != 0
+ }
+ toDelete.foreach { j => kvstore.delete(j.getClass(), j.execId) }
}
- private def trimExecutionIfNecessary() = {
- if (executionList.size > retainedStatements) {
- val toRemove = math.max(retainedStatements / 10, 1)
- executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s =>
- executionList.remove(s._1)
- }
+ private def cleanupSession(count: Long): Unit = {
+ val countToDelete = calculateNumberToRemove(count, retainedSessions)
+ if (countToDelete <= 0L) {
+ return
+ }
+ val view = kvstore.view(classOf[LiveSessionData]).index("sessionId").first(0L)
+ val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt) { j =>
+ j.finishTimestamp != 0L
}
+ toDelete.foreach { j => kvstore.delete(j.getClass(), j.sessionId) }
}
- private def trimSessionIfNecessary() = {
- if (sessionList.size > retainedSessions) {
- val toRemove = math.max(retainedSessions / 10, 1)
- sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s =>
- sessionList.remove(s._1)
- }
+ /**
+ * Remove at least (retainedSize / 10) items to reduce friction. Because tracking may be done
+ * asynchronously, this method may return 0 in case enough items have been deleted already.
+ */
+ private def calculateNumberToRemove(dataSize: Long, retainedSize: Long): Long = {
+ if (dataSize > retainedSize) {
+ math.max(retainedSize / 10L, dataSize - retainedSize)
+ } else {
+ 0L
}
+ }
+ private def updateLiveStore(entity: LiveEntity): Unit = {
+ val now = System.nanoTime()
+ entity.write(kvstore, now)
}
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2AppStatusStore.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2AppStatusStore.scala
new file mode 100644
index 000000000000..634464326b43
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2AppStatusStore.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver.ui
+
+import com.fasterxml.jackson.annotation.JsonIgnore
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionState, HiveThriftServer2Listener}
+import org.apache.spark.status.KVUtils.KVIndexParam
+import org.apache.spark.util.kvstore.KVStore
+
+class HiveThriftServer2AppStatusStore(
+ store: KVStore,
+ val listener: Option[HiveThriftServer2Listener] = None) {
+
+ def getSessionList: Seq[LiveSessionData] = {
+ store.view(classOf[LiveSessionData]).asScala.toSeq
+ }
+
+ def getExecutionList: Seq[LiveExecutionData] = {
+ store.view(classOf[LiveExecutionData]).asScala.toSeq
+ }
+
+ def getOnlineSessionNum: Int = {
+ store.view(classOf[LiveSessionData]).asScala.count(_.finishTimestamp == 0)
+ }
+
+ def getSession(sessionId: String): Option[LiveSessionData] = {
+ try {
+ Some(store.read(classOf[LiveSessionData], sessionId))
+ } catch {
+ case _: NoSuchElementException => None
+ }
+ }
+
+ /**
+ * When an error or a cancellation occurs, we set the finishTimestamp of the statement.
+ * Therefore, when we count the number of running statements, we need to exclude errors and
+ * cancellations and count all statements that have not been closed so far.
+ */
+ def getTotalRunning: Int = {
+ store.view(classOf[LiveExecutionData]).asScala.count(isExecutionActive)
+ }
+
+ def isExecutionActive(execInfo: LiveExecutionData): Boolean = {
+ !(execInfo.state == ExecutionState.FAILED ||
+ execInfo.state == ExecutionState.CANCELED ||
+ execInfo.state == ExecutionState.CLOSED)
+ }
+}
+
+private[thriftserver] class LiveSessionData(
+ @KVIndexParam val sessionId: String,
+ val startTimestamp: Long,
+ val ip: String,
+ val userName: String,
+ val finishTimestamp: Long,
+ val totalExecution: Long) {
+
+ @JsonIgnore @KVIndexParam("sessionId")
+ def session: String = sessionId
+
+ def totalTime: Long = {
+ if (finishTimestamp == 0L) {
+ System.currentTimeMillis - startTimestamp
+ } else {
+ finishTimestamp - startTimestamp
+ }
+ }
+}
+
+private[thriftserver] class LiveExecutionData(
+ @KVIndexParam val execId: String,
+ val statement: String,
+ val sessionId: String,
+ val startTimestamp: Long,
+ val userName: String,
+ val finishTimestamp: Long,
+ val closeTimestamp: Long,
+ val executePlan: String,
+ val detail: String,
+ val state: ExecutionState.Value,
+ val jobId: ArrayBuffer[String],
+ val groupId: String) {
+
+ @JsonIgnore @KVIndexParam("execId")
+ def exec(): String = execId
+
+ def totalTime(endTime: Long): Long = {
+ if (endTime == 0L) {
+ System.currentTimeMillis - startTimestamp
+ } else {
+ endTime - startTimestamp
+ }
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
index d22415709860..b5fef5daf50e 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
@@ -28,7 +28,6 @@ import scala.xml.{Node, Unparsed}
import org.apache.commons.text.StringEscapeUtils
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, SessionInfo}
import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._
import org.apache.spark.ui._
import org.apache.spark.ui.UIUtils._
@@ -37,18 +36,18 @@ import org.apache.spark.util.Utils
/** Page for Spark Web UI that shows statistics of the thrift server */
private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("") with Logging {
- private val listener = parent.listener
+ private val store = parent.store
private val startTime = Calendar.getInstance().getTime()
/** Render the page */
def render(request: HttpServletRequest): Seq[Node] = {
val content =
- listener.synchronized { // make sure all parts in this page are consistent
+ store.synchronized { // make sure all parts in this page are consistent
generateBasicStats() ++
++
- {listener.getOnlineSessionNum} session(s) are online,
- running {listener.getTotalRunning} SQL statement(s)
+ {store.getOnlineSessionNum} session(s) are online,
+ running {store.getTotalRunning} SQL statement(s)
++
generateSessionStatsTable(request) ++
generateSQLStatsTable(request)
@@ -72,7 +71,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
/** Generate stats of batch statements of the thrift server program */
private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = {
- val numStatement = listener.getExecutionList.size
+ val numStatement = store.getExecutionList.size
val table = if (numStatement > 0) {
@@ -103,7 +102,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
Some(new SqlStatsPagedTable(
request,
parent,
- listener.getExecutionList,
+ store.getExecutionList,
"sqlserver",
UIUtils.prependBaseUri(request, parent.basePath),
parameterOtherTable,
@@ -138,7 +137,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
/** Generate stats of batch sessions of the thrift server program */
private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = {
- val numSessions = listener.getSessionList.size
+ val numSessions = store.getSessionList.size
val table = if (numSessions > 0) {
val sessionTableTag = "sessionstat"
@@ -168,7 +167,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
Some(new SessionStatsPagedTable(
request,
parent,
- listener.getSessionList,
+ store.getSessionList,
"sqlserver",
UIUtils.prependBaseUri(request, parent.basePath),
parameterOtherTable,
@@ -205,7 +204,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
private[ui] class SqlStatsPagedTable(
request: HttpServletRequest,
parent: ThriftServerTab,
- data: Seq[ExecutionInfo],
+ data: Seq[LiveExecutionData],
subPath: String,
basePath: String,
parameterOtherTable: Iterable[String],
@@ -392,14 +391,14 @@ private[ui] class SqlStatsPagedTable(
private[ui] class SessionStatsPagedTable(
request: HttpServletRequest,
parent: ThriftServerTab,
- data: Seq[SessionInfo],
+ data: Seq[LiveSessionData],
subPath: String,
basePath: String,
parameterOtherTable: Iterable[String],
sessionStatsTableTag: String,
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedTable[SessionInfo] {
+ desc: Boolean) extends PagedTable[LiveSessionData] {
override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc)
@@ -471,7 +470,7 @@ private[ui] class SessionStatsPagedTable(
}
- override def row(session: SessionInfo): Seq[Node] = {
+ override def row(session: LiveSessionData): Seq[Node] = {
val sessionLink = "%s/%s/session/?id=%s".format(
UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId)
@@ -490,11 +489,11 @@ private[ui] class SessionStatsPagedTable(
val jobId: Seq[String],
val duration: Long,
val executionTime: Long,
- val executionInfo: ExecutionInfo,
+ val executionInfo: LiveExecutionData,
val detail: String)
private[ui] class SqlStatsTableDataSource(
- info: Seq[ExecutionInfo],
+ info: Seq[LiveExecutionData],
pageSize: Int,
sortColumn: String,
desc: Boolean) extends PagedDataSource[SqlStatsTableRow](pageSize) {
@@ -513,7 +512,7 @@ private[ui] class SessionStatsPagedTable(
r
}
- private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = {
+ private def sqlStatsTableRow(executionInfo: LiveExecutionData): SqlStatsTableRow = {
val duration = executionInfo.totalTime(executionInfo.closeTimestamp)
val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp)
val detail = Option(executionInfo.detail).filter(!_.isEmpty)
@@ -552,10 +551,10 @@ private[ui] class SessionStatsPagedTable(
}
private[ui] class SessionStatsTableDataSource(
- info: Seq[SessionInfo],
+ info: Seq[LiveSessionData],
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedDataSource[SessionInfo](pageSize) {
+ desc: Boolean) extends PagedDataSource[LiveSessionData](pageSize) {
// Sorting SessionInfo data
private val data = info.sorted(ordering(sortColumn, desc))
@@ -564,7 +563,7 @@ private[ui] class SessionStatsPagedTable(
override def dataSize: Int = data.size
- override def sliceData(from: Int, to: Int): Seq[SessionInfo] = {
+ override def sliceData(from: Int, to: Int): Seq[LiveSessionData] = {
val r = data.slice(from, to)
_slicedStartTime = r.map(_.startTimestamp).toSet
r
@@ -573,8 +572,8 @@ private[ui] class SessionStatsPagedTable(
/**
* Return Ordering according to sortColumn and desc.
*/
- private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = {
- val ordering: Ordering[SessionInfo] = sortColumn match {
+ private def ordering(sortColumn: String, desc: Boolean): Ordering[LiveSessionData] = {
+ val ordering: Ordering[LiveSessionData] = sortColumn match {
case "User" => Ordering.by(_.userName)
case "IP" => Ordering.by(_.ip)
case "Session ID" => Ordering.by(_.sessionId)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
index 8b275f8f7be0..839c11f80866 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.Utils
private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
extends WebUIPage("session") with Logging {
- private val listener = parent.listener
+ private val store = parent.store
private val startTime = Calendar.getInstance().getTime()
/** Render the page */
@@ -41,8 +41,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
val content =
- listener.synchronized { // make sure all parts in this page are consistent
- val sessionStat = listener.getSession(parameterId).getOrElse(null)
+ store.synchronized { // make sure all parts in this page are consistent
+ val sessionStat = store.getSession(parameterId).getOrElse(null)
require(sessionStat != null, "Invalid sessionID[" + parameterId + "]")
generateBasicStats() ++
@@ -73,7 +73,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
/** Generate stats of batch statements of the thrift server program */
private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = {
- val executionList = listener.getExecutionList
+ val executionList = store.getExecutionList
.filter(_.sessionId == sessionID)
val numStatement = executionList.size
val table = if (numStatement > 0) {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala
index 8efb2c3311cf..78434d944c65 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala
@@ -27,7 +27,9 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
* Spark Web UI tab that shows statistics of jobs running in the thrift server.
* This assumes the given SparkContext has enabled its SparkUI.
*/
-private[thriftserver] class ThriftServerTab(sparkContext: SparkContext)
+private[thriftserver] class ThriftServerTab(
+ val store: HiveThriftServer2AppStatusStore,
+ sparkContext: SparkContext)
extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging {
override val name = "JDBC/ODBC Server"
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2ListenerSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2ListenerSuite.scala
new file mode 100644
index 000000000000..eb981f330dac
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2ListenerSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.Properties
+
+import org.mockito.Mockito.{mock, RETURNS_SMART_NULLS}
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.scheduler.SparkListenerJobStart
+import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.HiveThriftServer2Listener
+import org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2AppStatusStore
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.status.ElementTrackingStore
+import org.apache.spark.util.kvstore.InMemoryStore
+
+class HiveThriftServer2ListenerSuite extends SparkFunSuite
+ with BeforeAndAfter with SharedSparkSession {
+
+ private var kvstore: ElementTrackingStore = _
+
+ after {
+ if (kvstore != null) {
+ kvstore.close()
+ kvstore = null
+ }
+ }
+
+ private def createProperties: Properties = {
+ val properties = new Properties()
+ properties.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "groupId")
+ properties
+ }
+
+ private def createStatusStore(): HiveThriftServer2AppStatusStore = {
+ val conf = sqlContext.conf
+
+ kvstore = new ElementTrackingStore(new InMemoryStore, new SparkConf())
+ val server = mock(classOf[HiveThriftServer2], RETURNS_SMART_NULLS)
+ val listener = new HiveThriftServer2Listener(kvstore, server, conf)
+ new HiveThriftServer2AppStatusStore(kvstore, Some(listener))
+ }
+
+ test("listener events should store successfully") {
+ val statusStore = createStatusStore()
+ val listener = statusStore.listener.get
+
+ listener.onSessionCreated("localhost", "sessionId", "user")
+ listener.onStatementStart("id", "sessionId", "dummy query", "groupId", "user")
+ listener.onStatementParsed("id", "dummy plan")
+ listener.onJobStart(SparkListenerJobStart(
+ 0,
+ System.currentTimeMillis(),
+ Nil,
+ createProperties))
+ listener.onStatementFinish("id")
+ listener.onOperationClosed("id")
+
+ assert(statusStore.getOnlineSessionNum == 1)
+
+ listener.onSessionClosed("sessionId")
+
+ assert(statusStore.getOnlineSessionNum == 0)
+ assert(statusStore.getExecutionList.size == 1)
+
+ val storeExecData = statusStore.getExecutionList.head
+
+ assert(storeExecData.execId == "id")
+ assert(storeExecData.sessionId == "sessionId")
+ assert(storeExecData.executePlan == "dummy plan")
+ assert(storeExecData.jobId == Seq("0"))
+ }
+ }