@@ -25,8 +25,9 @@ import org.apache.spark.sql.DataFrame
2525import org .apache .spark .sql .execution .SQLExecution
2626import org .apache .spark .sql .kyuubi .SparkDatasetHelper
2727import org .apache .spark .sql .types ._
28-
2928import org .apache .kyuubi .{KyuubiSQLException , Logging }
29+ import org .apache .spark .rdd .RDD
30+
3031import org .apache .kyuubi .config .KyuubiConf .OPERATION_RESULT_MAX_ROWS
3132import org .apache .kyuubi .engine .spark .KyuubiSparkUtil ._
3233import org .apache .kyuubi .operation .{ArrayFetchIterator , FetchIterator , IterableFetchIterator , OperationState }
@@ -75,29 +76,6 @@ class ExecuteStatement(
7576 resultDF.take(maxRows)
7677 }
7778
78- protected def collectAsIterator (resultDF : DataFrame ): FetchIterator [_] = {
79- val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS .key).map(_.toInt)
80- .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS ))
81- if (incrementalCollect) {
82- if (resultMaxRows > 0 ) {
83- warn(s " Ignore ${OPERATION_RESULT_MAX_ROWS .key} on incremental collect mode. " )
84- }
85- info(" Execute in incremental collect mode" )
86- new IterableFetchIterator [Any ](new Iterable [Any ] {
87- override def iterator : Iterator [Any ] = incrementalCollectResult(resultDF)
88- })
89- } else {
90- val internalArray = if (resultMaxRows <= 0 ) {
91- info(" Execute in full collect mode" )
92- fullCollectResult(resultDF)
93- } else {
94- info(s " Execute with max result rows[ $resultMaxRows] " )
95- takeResult(resultDF, resultMaxRows)
96- }
97- new ArrayFetchIterator (internalArray)
98- }
99- }
100-
10179 protected def executeStatement (): Unit = withLocalProperties {
10280 try {
10381 setState(OperationState .RUNNING )
@@ -163,14 +141,33 @@ class ExecuteStatement(
163141 }
164142 }
165143
166- def convertComplexType (df : DataFrame ): DataFrame = {
167- SparkDatasetHelper .convertTopLevelComplexTypeToHiveString(df, timestampAsString)
168- }
169-
/**
 * Builds the metadata hints attached to this operation's result set.
 *
 * @return two `key=value` strings: one interpolating `resultFormat` and one interpolating
 *         `timestampAsString`
 */
170144 override def getResultSetMetadataHints (): Seq [String ] =
171145 Seq (
172146 s " __kyuubi_operation_result_format__= $resultFormat" ,
173147 s " __kyuubi_operation_result_arrow_timestampAsString__= $timestampAsString" )
148+
/**
 * Materialises the result of `resultDF` as a [[FetchIterator]], choosing the collection
 * strategy from `incrementalCollect` and the configured maximum number of result rows.
 *
 * The row limit is read from the Spark session conf first and falls back to the session
 * manager's conf when the key is not set there.
 *
 * @param resultDF the DataFrame whose result is to be fetched
 * @return an `IterableFetchIterator` in incremental collect mode, otherwise an
 *         `ArrayFetchIterator` over the collected array
 */
149+ private def collectAsIterator (resultDF : DataFrame ): FetchIterator [_] = {
// NOTE(review): `_.toInt` throws NumberFormatException if the session-level value is not a
// valid integer -- confirm the value is validated before it reaches this point.
150+ val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS .key).map(_.toInt)
151+ .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS ))
152+ if (incrementalCollect) {
// The row limit is not applied in incremental mode; only a warning is logged.
153+ if (resultMaxRows > 0 ) {
154+ warn(s " Ignore ${OPERATION_RESULT_MAX_ROWS .key} on incremental collect mode. " )
155+ }
156+ info(" Execute in incremental collect mode" )
// `incrementalCollectResult` is invoked each time `iterator` is requested, not eagerly here.
157+ new IterableFetchIterator [Any ](new Iterable [Any ] {
158+ override def iterator : Iterator [Any ] = incrementalCollectResult(resultDF)
159+ })
160+ } else {
// A non-positive limit means "no limit": collect everything; otherwise take at most
// `resultMaxRows` rows.
161+ val internalArray = if (resultMaxRows <= 0 ) {
162+ info(" Execute in full collect mode" )
163+ fullCollectResult(resultDF)
164+ } else {
165+ info(s " Execute with max result rows[ $resultMaxRows] " )
166+ takeResult(resultDF, resultMaxRows)
167+ }
168+ new ArrayFetchIterator (internalArray)
169+ }
170+ }
174171}
175172
176173class ArrowBasedExecuteStatement (
@@ -182,30 +179,36 @@ class ArrowBasedExecuteStatement(
182179 extends ExecuteStatement (session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
183180
/**
 * Returns an iterator over the Arrow batches of `resultDF`, obtained via `toLocalIterator`.
 *
 * NOTE(review): `toLocalIterator` is itself lazy, so jobs for individual partitions may still
 * be submitted while the iterator is consumed, i.e. after the execution scope opened by
 * `collectAsArrow` has been closed -- confirm this is acceptable for incremental mode.
 */
override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
  collectAsArrow(convertComplexType(resultDF)) { rdd =>
    rdd.toLocalIterator
  }
}

/** Collects all Arrow batches of `resultDF` into an array. */
override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
  collectAsArrow(convertComplexType(resultDF)) { rdd =>
    rdd.collect()
  }
}

/** Collects the Arrow batches of `resultDF` after limiting it to at most `maxRows` rows. */
override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
  // this will introduce shuffle and hurt performance
  val limitedResult = resultDF.limit(maxRows)
  collectAsArrow(convertComplexType(limitedResult)) { rdd =>
    rdd.collect()
  }
}

/**
 * Refer to org.apache.spark.sql.Dataset#withAction(): assign a new execution id for the
 * arrow-based operation, so that we can track the arrow-based queries on the UI tab.
 *
 * The RDD returned by `toArrowBatchRdd` is only a description of the computation; Spark jobs
 * are submitted by the action applied to it. The action is therefore taken as a parameter and
 * run inside `withNewExecutionId`. Building the RDD inside the scope and calling `collect()`
 * on it afterwards would submit the jobs after the execution scope had already been closed.
 *
 * @param df the DataFrame to convert to an RDD of Arrow batches
 * @param action the RDD action that produces the result, run within the new execution id
 * @tparam T the result type of `action`
 * @return whatever `action` returns
 */
private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
  SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
    // Reset the plan metrics first, as the original block does before building the RDD.
    df.queryExecution.executedPlan.resetMetrics()
    action(SparkDatasetHelper.toArrowBatchRdd(df))
  }
}
207205
/** Always `true` for this class: flags the statement as an arrow-based operation. */
208206 override protected def isArrowBasedOperation : Boolean = true
209207
/** Overrides the result format name reported for this operation. */
210208 override val resultFormat = " arrow"
209+
/**
 * Delegates to `SparkDatasetHelper.convertTopLevelComplexTypeToHiveString`, passing `df`
 * together with this operation's `timestampAsString` setting.
 *
 * @param df the DataFrame to convert
 * @return the DataFrame returned by the helper
 */
private def convertComplexType(df: DataFrame): DataFrame =
  SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
213+
211214}
0 commit comments