ExecuteStatement.scala
@@ -22,13 +22,14 @@ import java.util.concurrent.RejectedExecutionException
import scala.collection.JavaConverters._

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.kyuubi.SparkDatasetHelper
import org.apache.spark.sql.types._

import org.apache.kyuubi.{KyuubiSQLException, Logging}
import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
import org.apache.kyuubi.operation.{ArrayFetchIterator, IterableFetchIterator, OperationState}
import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationState}
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session

@@ -62,49 +63,49 @@ class ExecuteStatement(
OperationLog.removeCurrentOperationLog()
}

private def executeStatement(): Unit = withLocalProperties {
protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
resultDF.toLocalIterator().asScala
}

protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
resultDF.collect()
}

protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
resultDF.take(maxRows)
}

protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
.getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
if (incrementalCollect) {
if (resultMaxRows > 0) {
warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
}
info("Execute in incremental collect mode")
new IterableFetchIterator[Any](new Iterable[Any] {
override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
})
} else {
val internalArray = if (resultMaxRows <= 0) {
info("Execute in full collect mode")
fullCollectResult(resultDF)
} else {
info(s"Execute with max result rows[$resultMaxRows]")
takeResult(resultDF, resultMaxRows)
}
new ArrayFetchIterator(internalArray)
}
}

protected def executeStatement(): Unit = withLocalProperties {
try {
setState(OperationState.RUNNING)
info(diagnostics)
Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
addOperationListener()
result = spark.sql(statement)

val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
.getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
iter = if (incrementalCollect) {
if (resultMaxRows > 0) {
warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
}
info("Execute in incremental collect mode")
def internalIterator(): Iterator[Any] = if (arrowEnabled) {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).toLocalIterator
} else {
result.toLocalIterator().asScala
}
new IterableFetchIterator[Any](new Iterable[Any] {
override def iterator: Iterator[Any] = internalIterator()
})
} else {
val internalArray = if (resultMaxRows <= 0) {
info("Execute in full collect mode")
if (arrowEnabled) {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).collect()
} else {
result.collect()
}
} else {
info(s"Execute with max result rows[$resultMaxRows]")
if (arrowEnabled) {
// this will introduce shuffle and hurt performance
val limitedResult = result.limit(resultMaxRows)
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
} else {
result.take(resultMaxRows)
}
}
new ArrayFetchIterator(internalArray)
}
iter = collectAsIterator(result)
setCompiledStateIfNeeded()
setState(OperationState.FINISHED)
} catch {
@@ -171,3 +172,40 @@ class ExecuteStatement(
s"__kyuubi_operation_result_format__=$resultFormat",
s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
}

class ArrowBasedExecuteStatement(
session: Session,
override val statement: String,
override val shouldRunAsync: Boolean,
queryTimeout: Long,
incrementalCollect: Boolean)
extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {

override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
}

override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
}

override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
// this will introduce shuffle and hurt performance
val limitedResult = resultDF.limit(maxRows)
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
}

/**
* assign a new execution id for arrow-based operation.
*/
override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
resultDF.queryExecution.executedPlan.resetMetrics()
super.collectAsIterator(resultDF)
}
}

override protected def isArrowBasedOperation: Boolean = true

override val resultFormat = "arrow"
}
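
The collectAsIterator override above wraps the Arrow collection in SQLExecution.withNewExecutionId: toArrowBatchRdd(...).collect() and .toLocalIterator are RDD-level actions, so without the wrapper their jobs would not be tied to a dedicated SQL execution in the Spark UI. A minimal standalone sketch of the same wrapping, with a local SparkSession, an example query, and a count() action standing in for the Arrow batch collection done by SparkDatasetHelper in the change itself:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

object NewExecutionIdSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    val df = spark.range(0, 10).toDF("id")
    // Run an RDD-level action under a freshly assigned execution id so it shows up
    // as its own query ("collectAsArrow") in the Spark UI, mirroring the override above.
    val rows = SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
      df.queryExecution.executedPlan.resetMetrics()
      df.queryExecution.toRdd.count() // stand-in for SparkDatasetHelper.toArrowBatchRdd(df).collect()
    }
    println(s"collected $rows rows")
    spark.stop()
  }
}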
SparkOperation.scala
@@ -245,7 +245,7 @@ abstract class SparkOperation(session: Session)
case FETCH_FIRST => iter.fetchAbsolute(0);
}
resultRowSet =
if (arrowEnabled) {
if (isArrowBasedOperation) {
if (iter.hasNext) {
val taken = iter.next().asInstanceOf[Array[Byte]]
RowSet.toTRowSet(taken, getProtocolVersion)
@@ -257,8 +257,7 @@
RowSet.toTRowSet(
taken.toSeq.asInstanceOf[Seq[Row]],
resultSchema,
getProtocolVersion,
timeZone)
getProtocolVersion)
}
resultRowSet.setStartRowOffset(iter.getPosition)
} catch onError(cancel = true)
@@ -268,16 +267,9 @@

override def shouldRunAsync: Boolean = false

protected def arrowEnabled: Boolean = {
resultFormat.equalsIgnoreCase("arrow") &&
// TODO: (fchen) make all operation support arrow
getClass.getCanonicalName == classOf[ExecuteStatement].getCanonicalName
}
protected def isArrowBasedOperation: Boolean = false

protected def resultFormat: String = {
// TODO: respect the config of the operation ExecuteStatement, if it was set.
spark.conf.get("kyuubi.operation.result.format", "thrift")
}
protected def resultFormat: String = "thrift"

protected def timestampAsString: Boolean = {
spark.conf.get("kyuubi.operation.result.arrow.timestampAsString", "false").toBoolean
SparkSQLOperationManager.scala
@@ -82,7 +82,24 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
case NoneMode =>
val incrementalCollect = spark.conf.getOption(OPERATION_INCREMENTAL_COLLECT.key)
.map(_.toBoolean).getOrElse(operationIncrementalCollectDefault)
new ExecuteStatement(session, statement, runAsync, queryTimeout, incrementalCollect)
// TODO: respect the config of the operation ExecuteStatement, if it was set.
val resultFormat = spark.conf.get("kyuubi.operation.result.format", "thrift")
resultFormat.toLowerCase match {
case "arrow" =>
new ArrowBasedExecuteStatement(
session,
statement,
runAsync,
queryTimeout,
incrementalCollect)
case _ =>
new ExecuteStatement(
session,
statement,
runAsync,
queryTimeout,
incrementalCollect)
}
case mode =>
new PlanOnlyStatement(session, statement, mode)
}
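
The dispatch above reads kyuubi.operation.result.format from the Spark session conf each time a statement operation is created, so a session can switch formats on the fly with a plain SET statement. A client-side sketch of that flow; the connection URL, credentials, and port are placeholders, and it assumes a JDBC driver that can decode Arrow-based result sets (such as Kyuubi's Hive JDBC driver) is on the classpath:

import java.sql.DriverManager

object ArrowFormatClientSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder URL: a Kyuubi endpoint listening on the default port.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10009/default", "anonymous", "")
    try {
      val stmt = conn.createStatement()
      // Sets the Spark session conf, so the next statement is created as an
      // ArrowBasedExecuteStatement instead of a plain ExecuteStatement.
      stmt.execute("SET kyuubi.operation.result.format=arrow")
      val rs = stmt.executeQuery("SELECT 1 AS c1")
      while (rs.next()) println(rs.getInt("c1"))
      stmt.close()
    } finally conn.close()
  }
}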
RowSet.scala
@@ -18,7 +18,6 @@
package org.apache.kyuubi.engine.spark.schema

import java.nio.ByteBuffer
import java.time.ZoneId

import scala.collection.JavaConverters._

@@ -61,16 +60,15 @@ object RowSet {
def toTRowSet(
rows: Seq[Row],
schema: StructType,
protocolVersion: TProtocolVersion,
timeZone: ZoneId): TRowSet = {
protocolVersion: TProtocolVersion): TRowSet = {
if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
toRowBasedSet(rows, schema, timeZone)
toRowBasedSet(rows, schema)
} else {
toColumnBasedSet(rows, schema, timeZone)
toColumnBasedSet(rows, schema)
}
}

def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
val rowSize = rows.length
val tRows = new java.util.ArrayList[TRow](rowSize)
var i = 0
@@ -80,7 +78,7 @@
var j = 0
val columnSize = row.length
while (j < columnSize) {
val columnValue = toTColumnValue(j, row, schema, timeZone)
val columnValue = toTColumnValue(j, row, schema)
tRow.addToColVals(columnValue)
j += 1
}
@@ -90,21 +88,21 @@
new TRowSet(0, tRows)
}

def toColumnBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
val rowSize = rows.length
val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
var i = 0
val columnSize = schema.length
while (i < columnSize) {
val field = schema(i)
val tColumn = toTColumn(rows, i, field.dataType, timeZone)
val tColumn = toTColumn(rows, i, field.dataType)
tRowSet.addToColumns(tColumn)
i += 1
}
tRowSet
}

private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType, timeZone: ZoneId): TColumn = {
private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
val nulls = new java.util.BitSet()
typ match {
case BooleanType =>
@@ -186,8 +184,7 @@ object RowSet {
private def toTColumnValue(
ordinal: Int,
row: Row,
types: StructType,
timeZone: ZoneId): TColumnValue = {
types: StructType): TColumnValue = {
types(ordinal).dataType match {
case BooleanType =>
val boolValue = new TBoolValue
SparkArrowbasedOperationSuite.scala
@@ -19,8 +19,14 @@ package org.apache.kyuubi.engine.spark.operation

import java.sql.Statement

import org.apache.spark.KyuubiSparkContextHelper
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
import org.apache.kyuubi.operation.SparkDataTypeTests

class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
@@ -85,6 +91,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
}
}

test("assign a new execution id for arrow-based result") {
var plan: LogicalPlan = null

val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
plan = qe.analyzed
}
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
withJdbcStatement() { statement =>
// since each new session has its own listener bus, we should register the listener
// in the current session.
SparkSQLEngine.currentEngine.get
.backendService
.sessionManager
.allSessions()
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))

val result = statement.executeQuery("select 1 as c1")
assert(result.next())
assert(result.getInt("c1") == 1)
}

KyuubiSparkContextHelper.waitListenerBus(spark)
spark.listenerManager.unregister(listener)
assert(plan.isInstanceOf[Project])
}

private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
val query =
s"""
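
The new test above relies on org.apache.spark.KyuubiSparkContextHelper.waitListenerBus to flush queued listener events before asserting on the captured plan. The helper itself is not part of this diff; presumably it sits in the org.apache.spark package so it can reach the private[spark] listener bus. A minimal sketch of such a helper under that assumption, with a made-up object name to avoid confusing it with the real one:

package org.apache.spark

import org.apache.spark.sql.SparkSession

// Blocks until every queued event on the shared listener bus has been delivered,
// so a QueryExecutionListener registered by a test is guaranteed to have run.
object ListenerBusWaitSketch {
  def waitListenerBus(spark: SparkSession): Unit = {
    spark.sparkContext.listenerBus.waitUntilEmpty()
  }
}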
RowSetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.schema
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate, ZoneId}
import java.time.{Instant, LocalDate}

import scala.collection.JavaConverters._

@@ -96,10 +96,9 @@
.add("q", "timestamp")

private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))
private val zoneId: ZoneId = ZoneId.systemDefault()

test("column based set") {
val tRowSet = RowSet.toColumnBasedSet(rows, schema, zoneId)
val tRowSet = RowSet.toColumnBasedSet(rows, schema)
assert(tRowSet.getColumns.size() === schema.size)
assert(tRowSet.getRowsSize === 0)

@@ -204,7 +203,7 @@
}

test("row based set") {
val tRowSet = RowSet.toRowBasedSet(rows, schema, zoneId)
val tRowSet = RowSet.toRowBasedSet(rows, schema)
assert(tRowSet.getColumnCount === 0)
assert(tRowSet.getRowsSize === rows.size)
val iter = tRowSet.getRowsIterator
@@ -250,7 +249,7 @@

test("to row set") {
TProtocolVersion.values().foreach { proto =>
val set = RowSet.toTRowSet(rows, schema, proto, zoneId)
val set = RowSet.toTRowSet(rows, schema, proto)
if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
assert(!set.isSetColumns, proto.toString)
assert(set.isSetRows, proto.toString)