diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
index 2b90525c1ec..6ebcce37728 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -22,13 +22,14 @@ import java.util.concurrent.RejectedExecutionException
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
-import org.apache.kyuubi.operation.{ArrayFetchIterator, IterableFetchIterator, OperationState}
+import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationState}
 import org.apache.kyuubi.operation.log.OperationLog
 import org.apache.kyuubi.session.Session
 
@@ -62,49 +63,49 @@ class ExecuteStatement(
     OperationLog.removeCurrentOperationLog()
   }
 
-  private def executeStatement(): Unit = withLocalProperties {
+  protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    resultDF.toLocalIterator().asScala
+  }
+
+  protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    resultDF.collect()
+  }
+
+  protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    resultDF.take(maxRows)
+  }
+
+  protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
+      .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
+    if (incrementalCollect) {
+      if (resultMaxRows > 0) {
+        warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
+      }
+      info("Execute in incremental collect mode")
+      new IterableFetchIterator[Any](new Iterable[Any] {
+        override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
+      })
+    } else {
+      val internalArray = if (resultMaxRows <= 0) {
+        info("Execute in full collect mode")
+        fullCollectResult(resultDF)
+      } else {
+        info(s"Execute with max result rows[$resultMaxRows]")
+        takeResult(resultDF, resultMaxRows)
+      }
+      new ArrayFetchIterator(internalArray)
+    }
+  }
+
+  protected def executeStatement(): Unit = withLocalProperties {
     try {
       setState(OperationState.RUNNING)
       info(diagnostics)
       Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
       addOperationListener()
       result = spark.sql(statement)
-
-      val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
-        .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
-      iter = if (incrementalCollect) {
-        if (resultMaxRows > 0) {
-          warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
-        }
-        info("Execute in incremental collect mode")
-        def internalIterator(): Iterator[Any] = if (arrowEnabled) {
-          SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).toLocalIterator
-        } else {
-          result.toLocalIterator().asScala
-        }
-        new IterableFetchIterator[Any](new Iterable[Any] {
-          override def iterator: Iterator[Any] = internalIterator()
-        })
-      } else {
-        val internalArray = if (resultMaxRows <= 0) {
-          info("Execute in full collect mode")
-          if (arrowEnabled) {
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).collect()
-          } else {
-            result.collect()
-          }
-        } else {
-          info(s"Execute with max result rows[$resultMaxRows]")
-          if (arrowEnabled) {
-            // this will introduce shuffle and hurt performance
-            val limitedResult = result.limit(resultMaxRows)
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
-          } else {
-            result.take(resultMaxRows)
-          }
-        }
-        new ArrayFetchIterator(internalArray)
-      }
+      iter = collectAsIterator(result)
       setCompiledStateIfNeeded()
       setState(OperationState.FINISHED)
     } catch {
@@ -171,3 +172,40 @@ class ExecuteStatement(
       s"__kyuubi_operation_result_format__=$resultFormat",
       s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
 }
+
+class ArrowBasedExecuteStatement(
+    session: Session,
+    override val statement: String,
+    override val shouldRunAsync: Boolean,
+    queryTimeout: Long,
+    incrementalCollect: Boolean)
+  extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
+
+  override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
+  }
+
+  override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
+  }
+
+  override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    // this will introduce shuffle and hurt performance
+    val limitedResult = resultDF.limit(maxRows)
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
+  }
+
+  /**
+   * Assign a new execution id for the arrow-based operation.
+   */
+  override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
+      resultDF.queryExecution.executedPlan.resetMetrics()
+      super.collectAsIterator(resultDF)
+    }
+  }
+
+  override protected def isArrowBasedOperation: Boolean = true
+
+  override val resultFormat = "arrow"
+}
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
index a6a7fc896af..eb58407d47c 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
@@ -245,7 +245,7 @@ abstract class SparkOperation(session: Session)
           case FETCH_FIRST => iter.fetchAbsolute(0);
         }
         resultRowSet =
-          if (arrowEnabled) {
+          if (isArrowBasedOperation) {
             if (iter.hasNext) {
               val taken = iter.next().asInstanceOf[Array[Byte]]
               RowSet.toTRowSet(taken, getProtocolVersion)
@@ -257,8 +257,7 @@ abstract class SparkOperation(session: Session)
             RowSet.toTRowSet(
               taken.toSeq.asInstanceOf[Seq[Row]],
               resultSchema,
-              getProtocolVersion,
-              timeZone)
+              getProtocolVersion)
           }
         resultRowSet.setStartRowOffset(iter.getPosition)
       } catch onError(cancel = true)
@@ -268,16 +267,9 @@ abstract class SparkOperation(session: Session)
 
   override def shouldRunAsync: Boolean = false
 
-  protected def arrowEnabled: Boolean = {
-    resultFormat.equalsIgnoreCase("arrow") &&
-    // TODO: (fchen) make all operation support arrow
-    getClass.getCanonicalName == classOf[ExecuteStatement].getCanonicalName
-  }
+  protected def isArrowBasedOperation: Boolean = false
 
-  protected def resultFormat: String = {
-    // TODO: respect the config of the operation ExecuteStatement, if it was set.
-    spark.conf.get("kyuubi.operation.result.format", "thrift")
-  }
+  protected def resultFormat: String = "thrift"
 
   protected def timestampAsString: Boolean = {
     spark.conf.get("kyuubi.operation.result.arrow.timestampAsString", "false").toBoolean
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
index 5c5ed0f9868..4743f147c4c 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
@@ -82,7 +82,24 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
         case NoneMode =>
           val incrementalCollect = spark.conf.getOption(OPERATION_INCREMENTAL_COLLECT.key)
             .map(_.toBoolean).getOrElse(operationIncrementalCollectDefault)
-          new ExecuteStatement(session, statement, runAsync, queryTimeout, incrementalCollect)
+          // TODO: respect the config of the operation ExecuteStatement, if it was set.
+          val resultFormat = spark.conf.get("kyuubi.operation.result.format", "thrift")
+          resultFormat.toLowerCase match {
+            case "arrow" =>
+              new ArrowBasedExecuteStatement(
+                session,
+                statement,
+                runAsync,
+                queryTimeout,
+                incrementalCollect)
+            case _ =>
+              new ExecuteStatement(
+                session,
+                statement,
+                runAsync,
+                queryTimeout,
+                incrementalCollect)
+          }
         case mode =>
           new PlanOnlyStatement(session, statement, mode)
       }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
index 7be70403d5d..4f935ce49f0 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
@@ -18,7 +18,6 @@
 package org.apache.kyuubi.engine.spark.schema
 
 import java.nio.ByteBuffer
-import java.time.ZoneId
 
 import scala.collection.JavaConverters._
 
@@ -61,16 +60,15 @@ object RowSet {
   def toTRowSet(
       rows: Seq[Row],
       schema: StructType,
-      protocolVersion: TProtocolVersion,
-      timeZone: ZoneId): TRowSet = {
+      protocolVersion: TProtocolVersion): TRowSet = {
     if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
-      toRowBasedSet(rows, schema, timeZone)
+      toRowBasedSet(rows, schema)
     } else {
-      toColumnBasedSet(rows, schema, timeZone)
+      toColumnBasedSet(rows, schema)
     }
   }
 
-  def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRows = new java.util.ArrayList[TRow](rowSize)
     var i = 0
@@ -80,7 +78,7 @@ object RowSet {
       var j = 0
       val columnSize = row.length
       while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema, timeZone)
+        val columnValue = toTColumnValue(j, row, schema)
         tRow.addToColVals(columnValue)
         j += 1
       }
@@ -90,21 +88,21 @@ object RowSet {
     new TRowSet(0, tRows)
   }
 
-  def toColumnBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
     var i = 0
     val columnSize = schema.length
     while (i < columnSize) {
       val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType, timeZone)
+      val tColumn = toTColumn(rows, i, field.dataType)
       tRowSet.addToColumns(tColumn)
       i += 1
     }
     tRowSet
   }
 
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType, timeZone: ZoneId): TColumn = {
+  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
     val nulls = new java.util.BitSet()
     typ match {
       case BooleanType =>
@@ -186,8 +184,7 @@ object RowSet {
   private def toTColumnValue(
       ordinal: Int,
       row: Row,
-      types: StructType,
-      timeZone: ZoneId): TColumnValue = {
+      types: StructType): TColumnValue = {
     types(ordinal).dataType match {
       case BooleanType =>
         val boolValue = new TBoolValue
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index 60cc528912d..30cdeca5abe 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -19,8 +19,14 @@ package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
 
+import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
 import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
+import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
+import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
 
 class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
@@ -85,6 +91,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     }
   }
 
+  test("assign a new execution id for arrow-based result") {
+    var plan: LogicalPlan = null
+
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+    withJdbcStatement() { statement =>
+      // since all the new sessions have their own listener bus, we should register the listener
+      // in the current session.
+      SparkSQLEngine.currentEngine.get
+        .backendService
+        .sessionManager
+        .allSessions()
+        .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
+
+      val result = statement.executeQuery("select 1 as c1")
+      assert(result.next())
+      assert(result.getInt("c1") == 1)
+    }
+
+    KyuubiSparkContextHelper.waitListenerBus(spark)
+    spark.listenerManager.unregister(listener)
+    assert(plan.isInstanceOf[Project])
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
index a999563ea49..5d2ba4a0d11 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.schema
 import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate, ZoneId}
+import java.time.{Instant, LocalDate}
 
 import scala.collection.JavaConverters._
 
@@ -96,10 +96,9 @@ class RowSetSuite extends KyuubiFunSuite {
     .add("q", "timestamp")
 
   private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))
-  private val zoneId: ZoneId = ZoneId.systemDefault()
 
   test("column based set") {
-    val tRowSet = RowSet.toColumnBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toColumnBasedSet(rows, schema)
     assert(tRowSet.getColumns.size() === schema.size)
     assert(tRowSet.getRowsSize === 0)
 
@@ -204,7 +203,7 @@ class RowSetSuite extends KyuubiFunSuite {
   }
 
   test("row based set") {
-    val tRowSet = RowSet.toRowBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toRowBasedSet(rows, schema)
     assert(tRowSet.getColumnCount === 0)
     assert(tRowSet.getRowsSize === rows.size)
     val iter = tRowSet.getRowsIterator
@@ -250,7 +249,7 @@ class RowSetSuite extends KyuubiFunSuite {
 
   test("to row set") {
     TProtocolVersion.values().foreach { proto =>
-      val set = RowSet.toTRowSet(rows, schema, proto, zoneId)
+      val set = RowSet.toTRowSet(rows, schema, proto)
       if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
         assert(!set.isSetColumns, proto.toString)
         assert(set.isSetRows, proto.toString)
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
new file mode 100644
index 00000000000..8293123ead7
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * A place to invoke non-public APIs of [[SparkContext]], for test only.
+ */
+object KyuubiSparkContextHelper {
+
+  def waitListenerBus(spark: SparkSession): Unit = {
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+  }
+}
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
index e297e6281ae..a42b05473a7 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
@@ -433,13 +433,13 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper {
       expectedFormat = "thrift")
     checkStatusAndResultSetFormatHint(
       sql = "set kyuubi.operation.result.format=arrow",
-      expectedFormat = "arrow")
+      expectedFormat = "thrift")
     checkStatusAndResultSetFormatHint(
       sql = "SELECT 1",
       expectedFormat = "arrow")
     checkStatusAndResultSetFormatHint(
      sql = "set kyuubi.operation.result.format=thrift",
-      expectedFormat = "thrift")
+      expectedFormat = "arrow")
     checkStatusAndResultSetFormatHint(
       sql = "set kyuubi.operation.result.format",
       expectedFormat = "thrift")
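Note (not part of the patch): the key behavioral change above is that ArrowBasedExecuteStatement.collectAsIterator wraps collection in SQLExecution.withNewExecutionId, since the arrow path collects through an RDD action rather than a Dataset action and would otherwise not be tracked as its own SQL execution. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the object name WithNewExecutionIdSketch, the execution name "collectSketch", and the local-mode SparkSession are made up for illustration; the SQLExecution and executedPlan calls mirror the ones used in the patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

object WithNewExecutionIdSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    val df = spark.range(0, 10).toDF("id")
    val qe = df.queryExecution

    // Jobs triggered inside this block are grouped under a fresh execution id
    // (shown as "collectSketch" on the Spark UI SQL tab), analogous to what
    // ArrowBasedExecuteStatement.collectAsIterator does before delegating to super.
    val rows = SQLExecution.withNewExecutionId(qe, Some("collectSketch")) {
      qe.executedPlan.resetMetrics()
      qe.executedPlan.executeCollect()
    }
    println(s"collected ${rows.length} rows")
    spark.stop()
  }
}
```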