diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index 10b1783242e..f9e4deae790 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.kyuubi
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
+import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -68,17 +67,20 @@ object SparkDatasetHelper extends Logging {
   def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
     val schemaCaptured = plan.schema
     // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we
-    //  drop Spark-3.1.x support.
+    //  drop Spark 3.1 support.
     val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+    // Note that we can't pass the lazy variable `maxBatchSize` directly, because input
+    // arguments are serialized and sent to the executor side for execution.
+    val maxBatchSizePerBatch = maxBatchSize
     plan.execute().mapPartitionsInternal { iter =>
-      val context = TaskContext.get()
-      ArrowConverters.toBatchIterator(
+      KyuubiArrowConverters.toBatchIterator(
         iter,
         schemaCaptured,
         maxRecordsPerBatch,
-        timeZoneId,
-        context)
+        maxBatchSizePerBatch,
+        -1,
+        timeZoneId)
     }
   }
 
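For context on why the patch copies the lazy `maxBatchSize` into the local `maxBatchSizePerBatch` before entering the closure: a minimal sketch of the pitfall, with hypothetical names (`Helper`, `driverSideConfig`, `tagWithBatchSize`) that are not part of the Kyuubi code. A closure that references a lazy val of the enclosing object resolves it against the singleton on each executor JVM and re-runs the lazy initializer there, where the driver-side state it reads is unavailable; a local val is captured as a plain, already-computed value instead.

```scala
import org.apache.spark.rdd.RDD

// Minimal sketch, not Kyuubi code: `Helper` and `driverSideConfig` are
// hypothetical stand-ins for an object holding driver-only configuration.
object Helper {
  // Lazily computed from driver-side state; executors should never evaluate this.
  lazy val maxBatchSize: Long = driverSideConfig("maxBatchSize")

  def tagWithBatchSize(rdd: RDD[Int]): RDD[(Int, Long)] = {
    // Copy the lazy val into a local before the closure: each task then captures
    // a plain Long computed once on the driver, instead of re-resolving
    // `Helper.maxBatchSize` (and re-running its initializer) on the executor.
    val localMaxBatchSize = maxBatchSize
    rdd.map(i => (i, localMaxBatchSize))
  }

  // Hypothetical driver-only lookup; in SparkDatasetHelper the real value comes
  // from configuration that is only available on the driver.
  private def driverSideConfig(key: String): Long =
    sys.props.getOrElse(s"demo.$key", "4194304").toLong
}
```

This is the same move the hunk above makes: `maxBatchSizePerBatch` carries the driver-computed value to executors as a captured primitive, which is then handed to `KyuubiArrowConverters.toBatchIterator`.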