diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9dd927084298..bb2f5a8ac809 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -125,6 +125,11 @@ net.sf.jpam jpam + + org.mockito + mockito-core + test + target/scala-${scala.binary.version}/classes diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 9517a599be63..835cbc3c867b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -36,8 +36,9 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, S import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.hive.thriftserver.ui.{HiveThriftServer2AppStatusStore, LiveExecutionData, LiveSessionData, ThriftServerTab} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity} import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -62,10 +63,13 @@ object HiveThriftServer2 extends Logging { server.init(executionHive.conf) server.start() - listener = new HiveThriftServer2Listener(server, sqlContext.conf) + val kvstore = sqlContext.sparkContext.statusStore.store.asInstanceOf[ElementTrackingStore] + listener = new HiveThriftServer2Listener(kvstore, server, sqlContext.conf) sqlContext.sparkContext.addSparkListener(listener) uiTab = if (sqlContext.sparkContext.getConf.get(UI_ENABLED)) { - Some(new ThriftServerTab(sqlContext.sparkContext)) + Some(new ThriftServerTab( + new HiveThriftServer2AppStatusStore(kvstore, 
Some(listener)), + sqlContext.sparkContext)) } else { None } @@ -101,10 +105,16 @@ object HiveThriftServer2 extends Logging { server.init(executionHive.conf) server.start() logInfo("HiveThriftServer2 started") - listener = new HiveThriftServer2Listener(server, SparkSQLEnv.sqlContext.conf) + val kvstore = SparkSQLEnv.sparkContext.statusStore.store + .asInstanceOf[ElementTrackingStore] + listener = new HiveThriftServer2Listener( + kvstore, + server, + SparkSQLEnv.sqlContext.conf) SparkSQLEnv.sparkContext.addSparkListener(listener) uiTab = if (SparkSQLEnv.sparkContext.getConf.get(UI_ENABLED)) { - Some(new ThriftServerTab(SparkSQLEnv.sparkContext)) + Some(new ThriftServerTab(new HiveThriftServer2AppStatusStore(kvstore, Some(listener)), + SparkSQLEnv.sparkContext)) } else { None } @@ -125,16 +135,20 @@ object HiveThriftServer2 extends Logging { val sessionId: String, val startTimestamp: Long, val ip: String, - val userName: String) { + val userName: String) extends LiveEntity { var finishTimestamp: Long = 0L var totalExecution: Int = 0 - def totalTime: Long = { - if (finishTimestamp == 0L) { - System.currentTimeMillis - startTimestamp - } else { - finishTimestamp - startTimestamp - } + + override protected def doUpdate(): Any = { + new LiveSessionData( + sessionId, + startTimestamp, + ip, + userName, + finishTimestamp, + totalExecution) } + } private[thriftserver] object ExecutionState extends Enumeration { @@ -143,10 +157,11 @@ object HiveThriftServer2 extends Logging { } private[thriftserver] class ExecutionInfo( + val execId: String, val statement: String, val sessionId: String, val startTimestamp: Long, - val userName: String) { + val userName: String) extends LiveEntity { var finishTimestamp: Long = 0L var closeTimestamp: Long = 0L var executePlan: String = "" @@ -154,12 +169,21 @@ object HiveThriftServer2 extends Logging { var state: ExecutionState.Value = ExecutionState.STARTED val jobId: ArrayBuffer[String] = ArrayBuffer[String]() var groupId: String = "" - def 
totalTime(endTime: Long): Long = { - if (endTime == 0L) { - System.currentTimeMillis - startTimestamp - } else { - endTime - startTimestamp - } + + override protected def doUpdate(): Any = { + new LiveExecutionData( + execId, + statement, + sessionId, + startTimestamp, + userName, + finishTimestamp, + closeTimestamp, + executePlan, + detail, + state, + jobId, + groupId) } } @@ -168,6 +192,7 @@ object HiveThriftServer2 extends Logging { * An inner sparkListener called in sc.stop to clean up the HiveThriftServer2 */ private[thriftserver] class HiveThriftServer2Listener( + val kvstore: ElementTrackingStore, val server: HiveServer2, val conf: SQLConf) extends SparkListener { @@ -179,35 +204,14 @@ object HiveThriftServer2 extends Logging { private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) - def getOnlineSessionNum: Int = synchronized { - sessionList.count(_._2.finishTimestamp == 0) + kvstore.addTrigger(classOf[LiveSessionData], retainedSessions) { count => + cleanupSession(count) } - def isExecutionActive(execInfo: ExecutionInfo): Boolean = { - !(execInfo.state == ExecutionState.FAILED || - execInfo.state == ExecutionState.CANCELED || - execInfo.state == ExecutionState.CLOSED) + kvstore.addTrigger(classOf[LiveExecutionData], retainedStatements) { count => + cleanupExecutions(count) } - /** - * When an error or a cancellation occurs, we set the finishTimestamp of the statement. - * Therefore, when we count the number of running statements, we need to exclude errors and - * cancellations and count all statements that have not been closed so far. 
- */ - def getTotalRunning: Int = synchronized { - executionList.count { - case (_, v) => isExecutionActive(v) - } - } - - def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq } - - def getSession(sessionId: String): Option[SessionInfo] = synchronized { - sessionList.get(sessionId) - } - - def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq } - override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { for { props <- Option(jobStart.properties) @@ -216,6 +220,7 @@ object HiveThriftServer2 extends Logging { } { info.jobId += jobStart.jobId.toString info.groupId = groupId + updateLiveStore(info) } } @@ -223,13 +228,14 @@ object HiveThriftServer2 extends Logging { synchronized { val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) sessionList.put(sessionId, info) - trimSessionIfNecessary() + updateLiveStore(info) } } def onSessionClosed(sessionId: String): Unit = synchronized { sessionList(sessionId).finishTimestamp = System.currentTimeMillis - trimSessionIfNecessary() + updateLiveStore(sessionList(sessionId)) + sessionList.remove(sessionId) } def onStatementStart( @@ -238,60 +244,86 @@ object HiveThriftServer2 extends Logging { statement: String, groupId: String, userName: String = "UNKNOWN"): Unit = synchronized { - val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) + val info = new ExecutionInfo(id, statement, sessionId, System.currentTimeMillis, userName) info.state = ExecutionState.STARTED executionList.put(id, info) - trimExecutionIfNecessary() sessionList(sessionId).totalExecution += 1 executionList(id).groupId = groupId + updateLiveStore(sessionList(sessionId)) + updateLiveStore(executionList(id)) } def onStatementParsed(id: String, executionPlan: String): Unit = synchronized { executionList(id).executePlan = executionPlan executionList(id).state = ExecutionState.COMPILED + updateLiveStore(executionList(id)) } def 
/**
 * Retention trigger for statement snapshots: once the store holds more than
 * `retainedStatements` entries, evict finished executions.
 *
 * @param count number of [[LiveExecutionData]] entries currently tracked, as reported by
 *              the trigger registered through `kvstore.addTrigger`.
 */
private def cleanupExecutions(count: Long): Unit = {
  val countToDelete = calculateNumberToRemove(count, retainedStatements)
  if (countToDelete > 0L) {
    // The "execId" index is keyed by String, so the view must not be bounded with a
    // numeric `first(0L)`; finished entries are selected by the filter below instead.
    val view = kvstore.view(classOf[LiveExecutionData]).index("execId")
    // Clamp before narrowing so an oversized count cannot wrap around to a negative Int.
    val limit = math.min(countToDelete, Int.MaxValue.toLong).toInt
    val toDelete = KVUtils.viewToSeq(view, limit) { exec =>
      // Only statements that already have a finish timestamp are eligible for eviction.
      exec.finishTimestamp != 0L
    }
    toDelete.foreach { exec => kvstore.delete(exec.getClass(), exec.execId) }
  }
}

/**
 * Retention trigger for session snapshots: once the store holds more than
 * `retainedSessions` entries, evict finished sessions.
 *
 * @param count number of [[LiveSessionData]] entries currently tracked, as reported by
 *              the trigger registered through `kvstore.addTrigger`.
 */
private def cleanupSession(count: Long): Unit = {
  val countToDelete = calculateNumberToRemove(count, retainedSessions)
  if (countToDelete > 0L) {
    // Same reasoning as for executions: "sessionId" is a String index, so no numeric
    // lower bound is applied to the view.
    val view = kvstore.view(classOf[LiveSessionData]).index("sessionId")
    val limit = math.min(countToDelete, Int.MaxValue.toLong).toInt
    val toDelete = KVUtils.viewToSeq(view, limit) { session =>
      // Sessions that are still open (finishTimestamp == 0) are never evicted.
      session.finishTimestamp != 0L
    }
    toDelete.foreach { session => kvstore.delete(session.getClass(), session.sessionId) }
  }
}

/**
 * Remove at least (retainedSize / 10) items to reduce friction. Because tracking may be done
 * asynchronously, this method may return 0 in case enough items have been deleted already.
 *
 * @param dataSize     number of items currently stored.
 * @param retainedSize maximum number of items that should be kept.
 * @return how many items to remove; 0 when `dataSize` does not exceed `retainedSize`.
 */
private def calculateNumberToRemove(dataSize: Long, retainedSize: Long): Long = {
  if (dataSize > retainedSize) {
    math.max(retainedSize / 10L, dataSize - retainedSize)
  } else {
    0L
  }
}

/**
 * Writes the entity to the kvstore, passing the current `System.nanoTime()` reading as
 * the write time.
 */
private def updateLiveStore(entity: LiveEntity): Unit = {
  entity.write(kvstore, System.nanoTime())
}
/**
 * Read-side accessor for the thrift server UI: every query is answered from the
 * [[LiveSessionData]] and [[LiveExecutionData]] entries held in the given store.
 *
 * @param store    key-value store the session and execution snapshots are read from.
 * @param listener optional listener associated with this store; it is only exposed to
 *                 callers and is not used by any query in this class.
 */
class HiveThriftServer2AppStatusStore(
    store: KVStore,
    val listener: Option[HiveThriftServer2Listener] = None) {

  /** All session snapshots currently in the store. */
  def getSessionList: Seq[LiveSessionData] = {
    val sessions = store.view(classOf[LiveSessionData])
    sessions.asScala.toSeq
  }

  /** All execution snapshots currently in the store. */
  def getExecutionList: Seq[LiveExecutionData] = {
    val executions = store.view(classOf[LiveExecutionData])
    executions.asScala.toSeq
  }

  /** Number of sessions whose finish timestamp is still unset (0). */
  def getOnlineSessionNum: Int = {
    val sessions = store.view(classOf[LiveSessionData]).asScala
    sessions.count(session => session.finishTimestamp == 0)
  }

  /**
   * Looks up a single session by id.
   *
   * @return the stored snapshot, or `None` when the read raises `NoSuchElementException`.
   */
  def getSession(sessionId: String): Option[LiveSessionData] = {
    try {
      val session = store.read(classOf[LiveSessionData], sessionId)
      Some(session)
    } catch {
      case _: NoSuchElementException => None
    }
  }

  /**
   * When an error or a cancellation occurs, we set the finishTimestamp of the statement.
   * Therefore, when we count the number of running statements, we need to exclude errors and
   * cancellations and count all statements that have not been closed so far.
   */
  def getTotalRunning: Int = {
    val executions = store.view(classOf[LiveExecutionData]).asScala
    executions.count(execution => isExecutionActive(execution))
  }

  /** An execution is active unless its state is FAILED, CANCELED or CLOSED. */
  def isExecutionActive(execInfo: LiveExecutionData): Boolean = {
    val state = execInfo.state
    state != ExecutionState.FAILED &&
      state != ExecutionState.CANCELED &&
      state != ExecutionState.CLOSED
  }
}

/**
 * Immutable snapshot of one thrift server session, keyed by `sessionId`.
 * A `finishTimestamp` of 0 means the session has not finished yet.
 */
private[thriftserver] class LiveSessionData(
    @KVIndexParam val sessionId: String,
    val startTimestamp: Long,
    val ip: String,
    val userName: String,
    val finishTimestamp: Long,
    val totalExecution: Long) {

  /** Secondary index named "sessionId"; excluded from JSON output. */
  @JsonIgnore @KVIndexParam("sessionId")
  def session: String = sessionId

  /** Elapsed milliseconds: up to now for an unfinished session, else up to its finish. */
  def totalTime: Long = {
    val endTimestamp = if (finishTimestamp == 0L) System.currentTimeMillis else finishTimestamp
    endTimestamp - startTimestamp
  }
}

/**
 * Immutable snapshot of one statement execution, keyed by `execId`.
 */
private[thriftserver] class LiveExecutionData(
    @KVIndexParam val execId: String,
    val statement: String,
    val sessionId: String,
    val startTimestamp: Long,
    val userName: String,
    val finishTimestamp: Long,
    val closeTimestamp: Long,
    val executePlan: String,
    val detail: String,
    val state: ExecutionState.Value,
    val jobId: ArrayBuffer[String],
    val groupId: String) {

  /** Secondary index named "execId"; excluded from JSON output. */
  @JsonIgnore @KVIndexParam("execId")
  def exec(): String = execId

  /**
   * Elapsed milliseconds since the statement started.
   *
   * @param endTime end timestamp to measure against; 0 means "not reached yet", in which
   *                case the current time is used instead.
   */
  def totalTime(endTime: Long): Long = {
    val effectiveEnd = if (endTime == 0L) System.currentTimeMillis else endTime
    effectiveEnd - startTimestamp
  }
}
org.apache.spark.internal.Logging -import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, SessionInfo} import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._ import org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ @@ -37,18 +36,18 @@ import org.apache.spark.util.Utils /** Page for Spark Web UI that shows statistics of the thrift server */ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("") with Logging { - private val listener = parent.listener + private val store = parent.store private val startTime = Calendar.getInstance().getTime() /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { val content = - listener.synchronized { // make sure all parts in this page are consistent + store.synchronized { // make sure all parts in this page are consistent generateBasicStats() ++
++

- {listener.getOnlineSessionNum} session(s) are online, - running {listener.getTotalRunning} SQL statement(s) + {store.getOnlineSessionNum} session(s) are online, + running {store.getTotalRunning} SQL statement(s)

++ generateSessionStatsTable(request) ++ generateSQLStatsTable(request) @@ -72,7 +71,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { - val numStatement = listener.getExecutionList.size + val numStatement = store.getExecutionList.size val table = if (numStatement > 0) { @@ -103,7 +102,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" Some(new SqlStatsPagedTable( request, parent, - listener.getExecutionList, + store.getExecutionList, "sqlserver", UIUtils.prependBaseUri(request, parent.basePath), parameterOtherTable, @@ -138,7 +137,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { - val numSessions = listener.getSessionList.size + val numSessions = store.getSessionList.size val table = if (numSessions > 0) { val sessionTableTag = "sessionstat" @@ -168,7 +167,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" Some(new SessionStatsPagedTable( request, parent, - listener.getSessionList, + store.getSessionList, "sqlserver", UIUtils.prependBaseUri(request, parent.basePath), parameterOtherTable, @@ -205,7 +204,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" private[ui] class SqlStatsPagedTable( request: HttpServletRequest, parent: ThriftServerTab, - data: Seq[ExecutionInfo], + data: Seq[LiveExecutionData], subPath: String, basePath: String, parameterOtherTable: Iterable[String], @@ -392,14 +391,14 @@ private[ui] class SqlStatsPagedTable( private[ui] class SessionStatsPagedTable( request: HttpServletRequest, parent: ThriftServerTab, - data: Seq[SessionInfo], + data: Seq[LiveSessionData], subPath: 
String, basePath: String, parameterOtherTable: Iterable[String], sessionStatsTableTag: String, pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[SessionInfo] { + desc: Boolean) extends PagedTable[LiveSessionData] { override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) @@ -471,7 +470,7 @@ private[ui] class SessionStatsPagedTable( } - override def row(session: SessionInfo): Seq[Node] = { + override def row(session: LiveSessionData): Seq[Node] = { val sessionLink = "%s/%s/session/?id=%s".format( UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId) @@ -490,11 +489,11 @@ private[ui] class SessionStatsPagedTable( val jobId: Seq[String], val duration: Long, val executionTime: Long, - val executionInfo: ExecutionInfo, + val executionInfo: LiveExecutionData, val detail: String) private[ui] class SqlStatsTableDataSource( - info: Seq[ExecutionInfo], + info: Seq[LiveExecutionData], pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[SqlStatsTableRow](pageSize) { @@ -513,7 +512,7 @@ private[ui] class SessionStatsPagedTable( r } - private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { + private def sqlStatsTableRow(executionInfo: LiveExecutionData): SqlStatsTableRow = { val duration = executionInfo.totalTime(executionInfo.closeTimestamp) val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) val detail = Option(executionInfo.detail).filter(!_.isEmpty) @@ -552,10 +551,10 @@ private[ui] class SessionStatsPagedTable( } private[ui] class SessionStatsTableDataSource( - info: Seq[SessionInfo], + info: Seq[LiveSessionData], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedDataSource[SessionInfo](pageSize) { + desc: Boolean) extends PagedDataSource[LiveSessionData](pageSize) { // Sorting SessionInfo data private val data = info.sorted(ordering(sortColumn, desc)) @@ -564,7 +563,7 @@ private[ui] class 
SessionStatsPagedTable( override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[SessionInfo] = { + override def sliceData(from: Int, to: Int): Seq[LiveSessionData] = { val r = data.slice(from, to) _slicedStartTime = r.map(_.startTimestamp).toSet r @@ -573,8 +572,8 @@ private[ui] class SessionStatsPagedTable( /** * Return Ordering according to sortColumn and desc. */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = { - val ordering: Ordering[SessionInfo] = sortColumn match { + private def ordering(sortColumn: String, desc: Boolean): Ordering[LiveSessionData] = { + val ordering: Ordering[LiveSessionData] = sortColumn match { case "User" => Ordering.by(_.userName) case "IP" => Ordering.by(_.ip) case "Session ID" => Ordering.by(_.sessionId) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 8b275f8f7be0..839c11f80866 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) extends WebUIPage("session") with Logging { - private val listener = parent.listener + private val store = parent.store private val startTime = Calendar.getInstance().getTime() /** Render the page */ @@ -41,8 +41,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val content = - listener.synchronized { // make sure all parts in this page are consistent - val sessionStat = listener.getSession(parameterId).getOrElse(null) + store.synchronized { // make sure all 
parts in this page are consistent + val sessionStat = store.getSession(parameterId).getOrElse(null) require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") generateBasicStats() ++ @@ -73,7 +73,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { - val executionList = listener.getExecutionList + val executionList = store.getExecutionList .filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 8efb2c3311cf..78434d944c65 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,7 +27,9 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} * Spark Web UI tab that shows statistics of jobs running in the thrift server. * This assumes the given SparkContext has enabled its SparkUI. 
*/ -private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) +private[thriftserver] class ThriftServerTab( + val store: HiveThriftServer2AppStatusStore, + sparkContext: SparkContext) extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging { override val name = "JDBC/ODBC Server" diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2ListenerSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2ListenerSuite.scala new file mode 100644 index 000000000000..eb981f330dac --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2ListenerSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Drives a [[HiveThriftServer2Listener]] backed by an in-memory store through one full
 * session/statement lifecycle and checks what the status store reports afterwards.
 */
class HiveThriftServer2ListenerSuite extends SparkFunSuite
  with BeforeAndAfter with SharedSparkSession {

  // Store created by the running test; closed and cleared after each test.
  private var kvstore: ElementTrackingStore = _

  after {
    if (kvstore != null) {
      kvstore.close()
      kvstore = null
    }
  }

  /** Job properties carrying the job group id that the started statement registers under. */
  private def createProperties: Properties = {
    val jobProperties = new Properties()
    jobProperties.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "groupId")
    jobProperties
  }

  /** Builds a fresh in-memory store, wires a listener to it and wraps both in a status store. */
  private def createStatusStore(): HiveThriftServer2AppStatusStore = {
    kvstore = new ElementTrackingStore(new InMemoryStore, new SparkConf())
    val mockedServer = mock(classOf[HiveThriftServer2], RETURNS_SMART_NULLS)
    val thriftListener = new HiveThriftServer2Listener(kvstore, mockedServer, sqlContext.conf)
    new HiveThriftServer2AppStatusStore(kvstore, Some(thriftListener))
  }

  test("listener events should store successfully") {
    val statusStore = createStatusStore()
    val listener = statusStore.listener.get

    // One session runs one statement that starts a single job, finishes and is closed.
    listener.onSessionCreated("localhost", "sessionId", "user")
    listener.onStatementStart("id", "sessionId", "dummy query", "groupId", "user")
    listener.onStatementParsed("id", "dummy plan")
    val jobStart = SparkListenerJobStart(0, System.currentTimeMillis(), Nil, createProperties)
    listener.onJobStart(jobStart)
    listener.onStatementFinish("id")
    listener.onOperationClosed("id")

    // The session is still open at this point.
    assert(statusStore.getOnlineSessionNum == 1)

    listener.onSessionClosed("sessionId")

    assert(statusStore.getOnlineSessionNum == 0)
    assert(statusStore.getExecutionList.size == 1)

    val storeExecData = statusStore.getExecutionList.head

    assert(storeExecData.execId == "id")
    assert(storeExecData.sessionId == "sessionId")
    assert(storeExecData.executePlan == "dummy plan")
    assert(storeExecData.jobId == Seq("0"))
  }
}