
Commit a38bb53

cfmcgrady authored and pan3793 committed
[KYUUBI #4392] [ARROW] Assign a new execution id for arrow-based result
### _Why are the changes needed?_

Assign a new execution id for arrow-based results, so that arrow-based queries can be tracked on the Spark SQL UI tab.

```sql
set kyuubi.operation.result.format=arrow;
select 1;
```

Before this PR:

![Screenshot 2023-02-21 17.23.08](https://user-images.githubusercontent.com/8537877/220303920-fbaf978b-ead7-4708-9094-bcc84e8fb47c.png)
![Screenshot 2023-02-21 17.23.19](https://user-images.githubusercontent.com/8537877/220303966-cb8dfeae-cd10-4c4f-add6-2650619fc5f9.png)

After this PR:

![Screenshot 2023-02-22 10.21.53](https://user-images.githubusercontent.com/8537877/220504608-f67a5f70-8c64-4e3b-89c2-c2ea54676217.png)
![Screenshot 2023-02-21 17.20.50](https://user-images.githubusercontent.com/8537877/220304021-9b845f44-96c3-41f2-a48a-a428f8c4823f.png)

### _How was this patch tested?_

- [ ] Add some test cases that check the changes thoroughly, including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run tests](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4392 from cfmcgrady/arrow-execution-id-2.

Closes #4392

481118a [Fu Chen] enable ut
c90674e [Fu Chen] address comment
6cc7af4 [Fu Chen] address comment
3f8a3ab [Fu Chen] fix ut
223a246 [Fu Chen] add KyuubiSparkContextHelper
bb7b28f [Fu Chen] fix style
879a150 [Fu Chen] unnecessary changes
a2b04f8 [Fu Chen] fix

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
(cherry picked from commit f0acff3)
Signed-off-by: Cheng Pan <chengpan@apache.org>
1 parent 501c97f commit a38bb53
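
The heart of the change is wrapping the arrow-batch collection in `SQLExecution.withNewExecutionId`. The sketch below is illustrative only (it is not part of the patch and assumes a local `SparkSession`, e.g. in a spark-shell): collecting through an RDD action bypasses the execution id that Dataset actions normally register, so nothing shows up on the SQL tab unless the action is wrapped explicitly, which is what `ArrowBasedExecuteStatement` now does.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

// Illustrative sketch only; `spark` and the query are placeholders.
val spark = SparkSession.builder().master("local[1]").appName("arrow-exec-id-demo").getOrCreate()
val df = spark.sql("select 1 as c1")

// An RDD-level action: no SQL execution id is registered, so the SQL tab stays empty.
val untracked = df.queryExecution.toRdd.collect()

// Wrapping the same action registers a fresh execution id, which is what
// ArrowBasedExecuteStatement.collectAsIterator does for the arrow batch RDD.
val tracked = SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
  df.queryExecution.toRdd.collect()
}

spark.stop()
```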

File tree

8 files changed: +178 -71 lines changed


externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala

Lines changed: 76 additions & 38 deletions
```diff
@@ -22,13 +22,14 @@ import java.util.concurrent.RejectedExecutionException
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
-import org.apache.kyuubi.operation.{ArrayFetchIterator, IterableFetchIterator, OperationState}
+import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationState}
 import org.apache.kyuubi.operation.log.OperationLog
 import org.apache.kyuubi.session.Session
 
@@ -62,49 +63,49 @@ class ExecuteStatement(
     OperationLog.removeCurrentOperationLog()
   }
 
-  private def executeStatement(): Unit = withLocalProperties {
+  protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    resultDF.toLocalIterator().asScala
+  }
+
+  protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    resultDF.collect()
+  }
+
+  protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    resultDF.take(maxRows)
+  }
+
+  protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
+      .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
+    if (incrementalCollect) {
+      if (resultMaxRows > 0) {
+        warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
+      }
+      info("Execute in incremental collect mode")
+      new IterableFetchIterator[Any](new Iterable[Any] {
+        override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
+      })
+    } else {
+      val internalArray = if (resultMaxRows <= 0) {
+        info("Execute in full collect mode")
+        fullCollectResult(resultDF)
+      } else {
+        info(s"Execute with max result rows[$resultMaxRows]")
+        takeResult(resultDF, resultMaxRows)
+      }
+      new ArrayFetchIterator(internalArray)
+    }
+  }
+
+  protected def executeStatement(): Unit = withLocalProperties {
     try {
       setState(OperationState.RUNNING)
       info(diagnostics)
       Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
       addOperationListener()
       result = spark.sql(statement)
-
-      val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
-        .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
-      iter = if (incrementalCollect) {
-        if (resultMaxRows > 0) {
-          warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
-        }
-        info("Execute in incremental collect mode")
-        def internalIterator(): Iterator[Any] = if (arrowEnabled) {
-          SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).toLocalIterator
-        } else {
-          result.toLocalIterator().asScala
-        }
-        new IterableFetchIterator[Any](new Iterable[Any] {
-          override def iterator: Iterator[Any] = internalIterator()
-        })
-      } else {
-        val internalArray = if (resultMaxRows <= 0) {
-          info("Execute in full collect mode")
-          if (arrowEnabled) {
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).collect()
-          } else {
-            result.collect()
-          }
-        } else {
-          info(s"Execute with max result rows[$resultMaxRows]")
-          if (arrowEnabled) {
-            // this will introduce shuffle and hurt performance
-            val limitedResult = result.limit(resultMaxRows)
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
-          } else {
-            result.take(resultMaxRows)
-          }
-        }
-        new ArrayFetchIterator(internalArray)
-      }
+      iter = collectAsIterator(result)
       setCompiledStateIfNeeded()
       setState(OperationState.FINISHED)
     } catch {
@@ -171,3 +172,40 @@ class ExecuteStatement(
       s"__kyuubi_operation_result_format__=$resultFormat",
       s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
 }
+
+class ArrowBasedExecuteStatement(
+    session: Session,
+    override val statement: String,
+    override val shouldRunAsync: Boolean,
+    queryTimeout: Long,
+    incrementalCollect: Boolean)
+  extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
+
+  override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
+  }
+
+  override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
+  }
+
+  override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    // this will introduce shuffle and hurt performance
+    val limitedResult = resultDF.limit(maxRows)
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
+  }
+
+  /**
+   * assign a new execution id for arrow-based operation.
+   */
+  override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
+      resultDF.queryExecution.executedPlan.resetMetrics()
+      super.collectAsIterator(resultDF)
+    }
+  }
+
+  override protected def isArrowBasedOperation: Boolean = true
+
+  override val resultFormat = "arrow"
+}
```
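
For orientation, here is a condensed, illustrative sketch of the shape of this refactor (simplified class and method names, not the real Kyuubi types): the base class keeps the collect-mode decision, while the arrow variant only swaps how rows are materialized and wraps the whole collection in a fresh execution id.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SQLExecution

// Simplified stand-in for ExecuteStatement: decides *when* to collect fully vs. take(n).
class RowCollector {
  protected def fullCollectResult(df: DataFrame): Array[_] = df.collect()
  protected def takeResult(df: DataFrame, maxRows: Int): Array[_] = df.take(maxRows)

  def materialize(df: DataFrame, maxRows: Int): Array[_] =
    if (maxRows <= 0) fullCollectResult(df) else takeResult(df, maxRows)
}

// Simplified stand-in for ArrowBasedExecuteStatement: same decision logic, but the whole
// collection runs under a new execution id so the query shows up on the SQL tab.
// (The real subclass also overrides the hooks to return serialized arrow batches
// via SparkDatasetHelper.toArrowBatchRdd.)
class ArrowAwareCollector extends RowCollector {
  override def materialize(df: DataFrame, maxRows: Int): Array[_] =
    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
      df.queryExecution.executedPlan.resetMetrics()
      super.materialize(df, maxRows)
    }
}
```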

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala

Lines changed: 4 additions & 12 deletions
```diff
@@ -245,7 +245,7 @@ abstract class SparkOperation(session: Session)
           case FETCH_FIRST => iter.fetchAbsolute(0);
         }
         resultRowSet =
-          if (arrowEnabled) {
+          if (isArrowBasedOperation) {
             if (iter.hasNext) {
               val taken = iter.next().asInstanceOf[Array[Byte]]
               RowSet.toTRowSet(taken, getProtocolVersion)
@@ -257,8 +257,7 @@ abstract class SparkOperation(session: Session)
             RowSet.toTRowSet(
               taken.toSeq.asInstanceOf[Seq[Row]],
               resultSchema,
-              getProtocolVersion,
-              timeZone)
+              getProtocolVersion)
           }
       resultRowSet.setStartRowOffset(iter.getPosition)
     } catch onError(cancel = true)
@@ -268,16 +267,9 @@ abstract class SparkOperation(session: Session)
 
   override def shouldRunAsync: Boolean = false
 
-  protected def arrowEnabled: Boolean = {
-    resultFormat.equalsIgnoreCase("arrow") &&
-      // TODO: (fchen) make all operation support arrow
-      getClass.getCanonicalName == classOf[ExecuteStatement].getCanonicalName
-  }
+  protected def isArrowBasedOperation: Boolean = false
 
-  protected def resultFormat: String = {
-    // TODO: respect the config of the operation ExecuteStatement, if it was set.
-    spark.conf.get("kyuubi.operation.result.format", "thrift")
-  }
+  protected def resultFormat: String = "thrift"
 
   protected def timestampAsString: Boolean = {
     spark.conf.get("kyuubi.operation.result.arrow.timestampAsString", "false").toBoolean
```

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala

Lines changed: 18 additions & 1 deletion
```diff
@@ -82,7 +82,24 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
       case NoneMode =>
         val incrementalCollect = spark.conf.getOption(OPERATION_INCREMENTAL_COLLECT.key)
           .map(_.toBoolean).getOrElse(operationIncrementalCollectDefault)
-        new ExecuteStatement(session, statement, runAsync, queryTimeout, incrementalCollect)
+        // TODO: respect the config of the operation ExecuteStatement, if it was set.
+        val resultFormat = spark.conf.get("kyuubi.operation.result.format", "thrift")
+        resultFormat.toLowerCase match {
+          case "arrow" =>
+            new ArrowBasedExecuteStatement(
+              session,
+              statement,
+              runAsync,
+              queryTimeout,
+              incrementalCollect)
+          case _ =>
+            new ExecuteStatement(
+              session,
+              statement,
+              runAsync,
+              queryTimeout,
+              incrementalCollect)
+        }
       case mode =>
         new PlanOnlyStatement(session, statement, mode)
     }
```
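
The dispatch above is driven entirely by the session's `kyuubi.operation.result.format` value. A hypothetical client-side session toggling it over JDBC might look like the sketch below (the driver URL, port, and credentials are placeholders for a local Kyuubi server and are not part of this patch):

```scala
import java.sql.DriverManager

// Hypothetical usage sketch; assumes a Hive-compatible JDBC driver on the classpath
// and a Kyuubi server listening on the default frontend port.
val conn = DriverManager.getConnection("jdbc:hive2://localhost:10009/default", "anonymous", "")
val stmt = conn.createStatement()

// Subsequent statements in this session are routed to ArrowBasedExecuteStatement.
stmt.execute("set kyuubi.operation.result.format=arrow")
val rs = stmt.executeQuery("select 1 as c1")
while (rs.next()) {
  println(rs.getInt("c1"))
}

// Switch back; later statements go through the plain ExecuteStatement again.
stmt.execute("set kyuubi.operation.result.format=thrift")
rs.close()
stmt.close()
conn.close()
```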

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala

Lines changed: 9 additions & 12 deletions
```diff
@@ -18,7 +18,6 @@
 package org.apache.kyuubi.engine.spark.schema
 
 import java.nio.ByteBuffer
-import java.time.ZoneId
 
 import scala.collection.JavaConverters._
 
@@ -61,16 +60,15 @@ object RowSet {
   def toTRowSet(
       rows: Seq[Row],
       schema: StructType,
-      protocolVersion: TProtocolVersion,
-      timeZone: ZoneId): TRowSet = {
+      protocolVersion: TProtocolVersion): TRowSet = {
     if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
-      toRowBasedSet(rows, schema, timeZone)
+      toRowBasedSet(rows, schema)
     } else {
-      toColumnBasedSet(rows, schema, timeZone)
+      toColumnBasedSet(rows, schema)
     }
   }
 
-  def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRows = new java.util.ArrayList[TRow](rowSize)
     var i = 0
@@ -80,7 +78,7 @@
       var j = 0
       val columnSize = row.length
       while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema, timeZone)
+        val columnValue = toTColumnValue(j, row, schema)
         tRow.addToColVals(columnValue)
         j += 1
       }
@@ -90,21 +88,21 @@
     new TRowSet(0, tRows)
   }
 
-  def toColumnBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
     var i = 0
     val columnSize = schema.length
     while (i < columnSize) {
       val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType, timeZone)
+      val tColumn = toTColumn(rows, i, field.dataType)
       tRowSet.addToColumns(tColumn)
       i += 1
     }
     tRowSet
   }
 
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType, timeZone: ZoneId): TColumn = {
+  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
     val nulls = new java.util.BitSet()
     typ match {
       case BooleanType =>
@@ -186,8 +184,7 @@
   private def toTColumnValue(
       ordinal: Int,
       row: Row,
-      types: StructType,
-      timeZone: ZoneId): TColumnValue = {
+      types: StructType): TColumnValue = {
     types(ordinal).dataType match {
       case BooleanType =>
         val boolValue = new TBoolValue
```

externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala

Lines changed: 35 additions & 1 deletion
```diff
@@ -19,8 +19,14 @@ package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
 
+import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
 import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
+import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
+import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
 
 class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
@@ -85,6 +91,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     }
   }
 
+  test("assign a new execution id for arrow-based result") {
+    var plan: LogicalPlan = null
+
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+    withJdbcStatement() { statement =>
+      // since all the new sessions have their owner listener bus, we should register the listener
+      // in the current session.
+      SparkSQLEngine.currentEngine.get
+        .backendService
+        .sessionManager
+        .allSessions()
+        .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
+
+      val result = statement.executeQuery("select 1 as c1")
+      assert(result.next())
+      assert(result.getInt("c1") == 1)
+    }
+
+    KyuubiSparkContextHelper.waitListenerBus(spark)
+    spark.listenerManager.unregister(listener)
+    assert(plan.isInstanceOf[Project])
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
```
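
The test above verifies the analyzed plan through a `QueryExecutionListener`. An alternative, shown only as an illustrative sketch (not part of this patch), is to observe the execution id directly by listening for `SparkListenerSQLExecutionStart` events, which are emitted only when a query runs under an execution id:

```scala
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart

// Counts SQL executions; arrow-based statements only bump this counter once they run
// under SQLExecution.withNewExecutionId, as introduced by this commit.
class SQLExecutionCounter extends SparkListener {
  val count = new AtomicInteger(0)

  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case _: SparkListenerSQLExecutionStart => count.incrementAndGet()
    case _ => // ignore everything else
  }
}

// Usage sketch: register on the engine's SparkContext before running the query.
// val counter = new SQLExecutionCounter
// spark.sparkContext.addSparkListener(counter)
```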

externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala

Lines changed: 4 additions & 5 deletions
```diff
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.schema
 import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate, ZoneId}
+import java.time.{Instant, LocalDate}
 
 import scala.collection.JavaConverters._
 
@@ -96,10 +96,9 @@
     .add("q", "timestamp")
 
   private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))
-  private val zoneId: ZoneId = ZoneId.systemDefault()
 
   test("column based set") {
-    val tRowSet = RowSet.toColumnBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toColumnBasedSet(rows, schema)
     assert(tRowSet.getColumns.size() === schema.size)
     assert(tRowSet.getRowsSize === 0)
 
@@ -204,7 +203,7 @@
   }
 
   test("row based set") {
-    val tRowSet = RowSet.toRowBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toRowBasedSet(rows, schema)
     assert(tRowSet.getColumnCount === 0)
     assert(tRowSet.getRowsSize === rows.size)
     val iter = tRowSet.getRowsIterator
@@ -250,7 +249,7 @@
 
   test("to row set") {
     TProtocolVersion.values().foreach { proto =>
-      val set = RowSet.toTRowSet(rows, schema, proto, zoneId)
+      val set = RowSet.toTRowSet(rows, schema, proto)
       if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
         assert(!set.isSetColumns, proto.toString)
         assert(set.isSetRows, proto.toString)
```
