Skip to content

Commit 542414e

Browse files
committed
make arrow-based query metrics trackable in SQL UI
1 parent 171473e commit 542414e

File tree

2 files changed

+83
-43
lines changed

2 files changed

+83
-43
lines changed

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala

Lines changed: 39 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,9 @@ import org.apache.spark.sql.DataFrame
2525
import org.apache.spark.sql.execution.SQLExecution
2626
import org.apache.spark.sql.kyuubi.SparkDatasetHelper
2727
import org.apache.spark.sql.types._
28-
2928
import org.apache.kyuubi.{KyuubiSQLException, Logging}
29+
import org.apache.spark.rdd.RDD
30+
3031
import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
3132
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
3233
import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationState}
@@ -75,29 +76,6 @@ class ExecuteStatement(
7576
resultDF.take(maxRows)
7677
}
7778

78-
protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
79-
val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
80-
.getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
81-
if (incrementalCollect) {
82-
if (resultMaxRows > 0) {
83-
warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
84-
}
85-
info("Execute in incremental collect mode")
86-
new IterableFetchIterator[Any](new Iterable[Any] {
87-
override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
88-
})
89-
} else {
90-
val internalArray = if (resultMaxRows <= 0) {
91-
info("Execute in full collect mode")
92-
fullCollectResult(resultDF)
93-
} else {
94-
info(s"Execute with max result rows[$resultMaxRows]")
95-
takeResult(resultDF, resultMaxRows)
96-
}
97-
new ArrayFetchIterator(internalArray)
98-
}
99-
}
100-
10179
protected def executeStatement(): Unit = withLocalProperties {
10280
try {
10381
setState(OperationState.RUNNING)
@@ -163,14 +141,33 @@ class ExecuteStatement(
163141
}
164142
}
165143

166-
def convertComplexType(df: DataFrame): DataFrame = {
167-
SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
168-
}
169-
170144
override def getResultSetMetadataHints(): Seq[String] =
171145
Seq(
172146
s"__kyuubi_operation_result_format__=$resultFormat",
173147
s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
148+
149+
private def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
150+
val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
151+
.getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
152+
if (incrementalCollect) {
153+
if (resultMaxRows > 0) {
154+
warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
155+
}
156+
info("Execute in incremental collect mode")
157+
new IterableFetchIterator[Any](new Iterable[Any] {
158+
override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
159+
})
160+
} else {
161+
val internalArray = if (resultMaxRows <= 0) {
162+
info("Execute in full collect mode")
163+
fullCollectResult(resultDF)
164+
} else {
165+
info(s"Execute with max result rows[$resultMaxRows]")
166+
takeResult(resultDF, resultMaxRows)
167+
}
168+
new ArrayFetchIterator(internalArray)
169+
}
170+
}
174171
}
175172

176173
class ArrowBasedExecuteStatement(
@@ -182,30 +179,36 @@ class ArrowBasedExecuteStatement(
182179
extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
183180

184181
override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
185-
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
182+
collectAsArrow(convertComplexType(resultDF)).toLocalIterator
186183
}
187184

188185
override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
189-
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
186+
collectAsArrow(convertComplexType(resultDF)).collect()
190187
}
191188

192189
override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
193190
// this will introduce shuffle and hurt performance
194191
val limitedResult = resultDF.limit(maxRows)
195-
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
192+
collectAsArrow(convertComplexType(limitedResult)).collect()
196193
}
197194

198195
/**
199-
* assign a new execution id for arrow-based operation.
196+
* refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
197+
* operation, so that we can track the arrow-based queries on the UI tab.
200198
*/
201-
override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
202-
SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
203-
resultDF.queryExecution.executedPlan.resetMetrics()
204-
super.collectAsIterator(resultDF)
199+
private def collectAsArrow(df: DataFrame): RDD[Array[Byte]] = {
200+
SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
201+
df.queryExecution.executedPlan.resetMetrics()
202+
SparkDatasetHelper.toArrowBatchRdd(df)
205203
}
206204
}
207205

208206
override protected def isArrowBasedOperation: Boolean = true
209207

210208
override val resultFormat = "arrow"
209+
210+
private def convertComplexType(df: DataFrame): DataFrame = {
211+
SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
212+
}
213+
211214
}

externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala

Lines changed: 44 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -103,22 +103,41 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
103103
withJdbcStatement() { statement =>
104104
// since all the new sessions have their own listener bus, we should register the listener
105105
// in the current session.
106-
SparkSQLEngine.currentEngine.get
107-
.backendService
108-
.sessionManager
109-
.allSessions()
110-
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
106+
registerListener(listener)
111107

112108
val result = statement.executeQuery("select 1 as c1")
113109
assert(result.next())
114110
assert(result.getInt("c1") == 1)
115111
}
116-
117112
KyuubiSparkContextHelper.waitListenerBus(spark)
118-
spark.listenerManager.unregister(listener)
113+
unregisterListener(listener)
119114
assert(plan.isInstanceOf[Project])
120115
}
121116

117+
test("arrow-based query metrics") {
118+
var queryExecution: QueryExecution = null
119+
120+
val listener = new QueryExecutionListener {
121+
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
122+
queryExecution = qe
123+
}
124+
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
125+
}
126+
withJdbcStatement() { statement =>
127+
registerListener(listener)
128+
val result = statement.executeQuery("select 1 as c1")
129+
assert(result.next())
130+
assert(result.getInt("c1") == 1)
131+
}
132+
133+
KyuubiSparkContextHelper.waitListenerBus(spark)
134+
unregisterListener(listener)
135+
136+
val metrics = queryExecution.executedPlan.collectLeaves().head.metrics
137+
assert(metrics.contains("numOutputRows"))
138+
assert(metrics("numOutputRows").value === 1)
139+
}
140+
122141
private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
123142
val query =
124143
s"""
@@ -140,4 +159,22 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
140159
assert(resultSet.next())
141160
assert(resultSet.getString("col") === expect)
142161
}
162+
163+
private def registerListener(listener: QueryExecutionListener): Unit = {
164+
// since all the new sessions have their own listener bus, we should register the listener
165+
// in the current session.
166+
SparkSQLEngine.currentEngine.get
167+
.backendService
168+
.sessionManager
169+
.allSessions()
170+
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
171+
}
172+
173+
private def unregisterListener(listener: QueryExecutionListener): Unit = {
174+
SparkSQLEngine.currentEngine.get
175+
.backendService
176+
.sessionManager
177+
.allSessions()
178+
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
179+
}
143180
}

0 commit comments

Comments
 (0)