org/apache/kyuubi/engine/spark/FetchIterator.scala (new file)
@@ -0,0 +1,110 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.spark

/**
* Borrowed from Apache Spark, see SPARK-33655
*/
private[engine] sealed trait FetchIterator[A] extends Iterator[A] {
[Review thread on this line]
Member: Let's honor Spark libraries and add some comment about it
Member (Author): Updated
  /**
   * Begin a fetch block, forward from the current position.
   * Resets the fetch start offset.
   */
  def fetchNext(): Unit

  /**
   * Begin a fetch block, moving the iterator back by `offset` rows from the start of the
   * previous fetch block.
   * Resets the fetch start offset.
   *
   * @param offset the number of rows to move the fetch start position backwards.
   */
  def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset)

  /**
   * Begin a fetch block, moving the iterator to the given position.
   * Resets the fetch start offset.
   *
   * @param pos the index to which the iterator position is moved.
   */
  def fetchAbsolute(pos: Long): Unit

  def getFetchStart: Long

  def getPosition: Long
}

private[engine] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] {
  private var fetchStart: Long = 0

  private var position: Long = 0

  override def fetchNext(): Unit = fetchStart = position

  override def fetchAbsolute(pos: Long): Unit = {
    position = (pos max 0) min src.length
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = position < src.length

  override def next(): A = {
    position += 1
    src(position.toInt - 1)
  }
}

private[engine] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] {
  private var iter: Iterator[A] = iterable.iterator

  private var fetchStart: Long = 0

  private var position: Long = 0

  override def fetchNext(): Unit = fetchStart = position

  override def fetchAbsolute(pos: Long): Unit = {
    val newPos = pos max 0
    if (newPos < position) resetPosition()
    while (position < newPos && hasNext) next()
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = iter.hasNext

  override def next(): A = {
    position += 1
    iter.next()
  }

  private def resetPosition(): Unit = {
    if (position != 0) {
      iter = iterable.iterator
      position = 0
      fetchStart = 0
    }
  }
}
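
To make the contract concrete, here is a small usage sketch (illustrative only, not part of the diff; it assumes package-level access, since both classes are private[engine]):

  val it = new ArrayFetchIterator((0 until 10).toArray)

  it.fetchNext()        // block starts at the current position: getFetchStart == 0
  it.take(4).toList     // List(0, 1, 2, 3); getPosition is now 4

  it.fetchNext()        // next block begins where the previous one stopped
  it.take(4).toList     // List(4, 5, 6, 7); getPosition is now 8

  it.fetchPrior(4)      // rewind 4 rows from the last block start: 4 - 4 = 0
  it.take(4).toList     // List(0, 1, 2, 3) again

  it.fetchAbsolute(99)  // out-of-range targets are clamped to src.length
  it.hasNext            // false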
org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import org.apache.kyuubi.{KyuubiSQLException, Logging}
-import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
+import org.apache.kyuubi.engine.spark.{ArrayFetchIterator, KyuubiSparkUtil}
import org.apache.kyuubi.operation.{OperationState, OperationType}
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
@@ -74,7 +74,7 @@ class ExecuteStatement(
debug(s"original result queryExecution: ${result.queryExecution}")
val castedResult = result.select(castCols: _*)
debug(s"casted result queryExecution: ${castedResult.queryExecution}")
-iter = castedResult.collect().toList.iterator
+iter = new ArrayFetchIterator(castedResult.collect())
setState(OperationState.FINISHED)
} catch {
onError(cancel = true)
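Design note (not spelled out in the diff): collect() has already materialized the full result set as an Array, so ArrayFetchIterator can serve fetchAbsolute with a simple index clamp and no recomputation; the metadata operations below produce small Lists and use IterableFetchIterator instead.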
org/apache/kyuubi/engine/spark/operation/GetCatalogs.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CAT
Expand All @@ -35,7 +36,7 @@ class GetCatalogs(spark: SparkSession, session: Session)

override protected def runInternal(): Unit = {
try {
-iter = SparkCatalogShim().getCatalogs(spark).toIterator
+iter = new IterableFetchIterator(SparkCatalogShim().getCatalogs(spark).toList)
} catch onError()
}
}
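
The .toList conversions here and in the following files matter: IterableFetchIterator requires a re-iterable collection, because any backward move discards the live iterator and replays from the start. A hypothetical sketch of that reset behavior:

  val it = new IterableFetchIterator(List("x", "y", "z"))
  it.fetchNext()
  it.take(2).toList    // List(x, y); getPosition is now 2
  it.fetchAbsolute(0)  // target < position: rebuilds via iterable.iterator, then re-advances
  it.take(2).toList    // List(x, y) again, replayed from scratch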
org/apache/kyuubi/engine/spark/operation/GetColumns.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -88,9 +89,8 @@ class GetColumns(
val schemaPattern = toJavaRegex(schemaName)
val tablePattern = toJavaRegex(tableName)
val columnPattern = toJavaRegex(columnName)
-iter = SparkCatalogShim()
-  .getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern)
-  .toList.iterator
+iter = new IterableFetchIterator(SparkCatalogShim()
+  .getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern).toList)
} catch {
onError()
}
org/apache/kyuubi/engine/spark/operation/GetFunctions.scala
@@ -22,6 +22,7 @@ import java.sql.DatabaseMetaData
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -70,7 +71,7 @@ class GetFunctions(
info.getClassName)
}
}
-iter = a.toList.iterator
+iter = new IterableFetchIterator(a.toList)
} catch {
onError()
}
org/apache/kyuubi/engine/spark/operation/GetSchemas.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
Expand All @@ -42,7 +43,7 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
try {
val schemaPattern = toJavaRegex(schema)
val rows = SparkCatalogShim().getSchemas(spark, catalogName, schemaPattern)
-iter = rows.toList.toIterator
+iter = new IterableFetchIterator(rows)
} catch onError()
}
}
org/apache/kyuubi/engine/spark/operation/GetTableTypes.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
Expand All @@ -33,6 +34,6 @@ class GetTableTypes(spark: SparkSession, session: Session)
}

override protected def runInternal(): Unit = {
-iter = SparkCatalogShim.sparkTableTypes.map(Row(_)).toList.iterator
+iter = new IterableFetchIterator(SparkCatalogShim.sparkTableTypes.map(Row(_)).toList)
}
}
org/apache/kyuubi/engine/spark/operation/GetTables.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -73,7 +74,7 @@ class GetTables(
} else {
catalogTablesAndViews
}
-iter = allTableAndViews.toList.iterator
+iter = new IterableFetchIterator(allTableAndViews)
} catch {
onError()
}
org/apache/kyuubi/engine/spark/operation/GetTypeInfo.scala
@@ -22,6 +22,7 @@ import java.sql.Types._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -83,7 +84,7 @@ class GetTypeInfo(spark: SparkSession, session: Session)
}

override protected def runInternal(): Unit = {
-iter = Seq(
+iter = new IterableFetchIterator(Seq(
toRow("VOID", NULL),
toRow("BOOLEAN", BOOLEAN),
toRow("TINYINT", TINYINT, 3),
Expand All @@ -101,6 +102,6 @@ class GetTypeInfo(spark: SparkSession, session: Session)
toRow("MAP", JAVA_OBJECT),
toRow("STRUCT", STRUCT),
toRow("INTERVAL", OTHER)
-).toList.iterator
+))
}
}
org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
@@ -25,8 +25,9 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.KyuubiSQLException
+import org.apache.kyuubi.engine.spark.FetchIterator
import org.apache.kyuubi.operation.{AbstractOperation, OperationState}
-import org.apache.kyuubi.operation.FetchOrientation.FetchOrientation
+import org.apache.kyuubi.operation.FetchOrientation._
import org.apache.kyuubi.operation.OperationState.OperationState
import org.apache.kyuubi.operation.OperationType.OperationType
import org.apache.kyuubi.operation.log.OperationLog
Expand All @@ -36,7 +37,7 @@ import org.apache.kyuubi.session.Session
abstract class SparkOperation(spark: SparkSession, opType: OperationType, session: Session)
extends AbstractOperation(opType, session) {

-protected var iter: Iterator[Row] = _
+protected var iter: FetchIterator[Row] = _

protected final val operationLog: OperationLog =
OperationLog.createOperationLog(session.handle, getHandle)
@@ -130,8 +131,15 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
validateDefaultFetchOrientation(order)
assertState(OperationState.FINISHED)
setHasResultSet(true)
+order match {
+  case FETCH_NEXT => iter.fetchNext()
+  case FETCH_PRIOR => iter.fetchPrior(rowSetSize)
+  case FETCH_FIRST => iter.fetchAbsolute(0)
+}
val taken = iter.take(rowSetSize)
-RowSet.toTRowSet(taken.toList, resultSchema, getProtocolVersion)
+val resultRowSet = RowSet.toTRowSet(taken.toList, resultSchema, getProtocolVersion)
+resultRowSet.setStartRowOffset(iter.getPosition)
+resultRowSet
}

override def shouldRunAsync: Boolean = false
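
Putting the SparkOperation change together: each incoming fetch request now selects an iterator operation before any rows are taken. An illustrative driver-side simulation (plain strings stand in for the FetchOrientation values; this is a sketch, not code from the PR):

  val it = new ArrayFetchIterator(Array("a", "b", "c", "d", "e"))

  def fetch(order: String, rowSetSize: Int): List[String] = {
    order match {
      case "FETCH_NEXT"  => it.fetchNext()
      case "FETCH_PRIOR" => it.fetchPrior(rowSetSize)
      case "FETCH_FIRST" => it.fetchAbsolute(0)
    }
    it.take(rowSetSize).toList
  }

  fetch("FETCH_NEXT", 2)   // List(a, b)  block start 0
  fetch("FETCH_NEXT", 2)   // List(c, d)  block start 2
  fetch("FETCH_PRIOR", 2)  // List(a, b)  rewound to 2 - 2 = 0
  fetch("FETCH_FIRST", 2)  // List(a, b)  back to the beginning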