@@ -19,14 +19,13 @@ package org.apache.spark.sql.kyuubi
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
+import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -68,17 +67,20 @@ object SparkDatasetHelper extends Logging {
   def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
     val schemaCaptured = plan.schema
     // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we
-    // drop Spark-3.1.x support.
+    // drop Spark 3.1 support.
     val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+    // Note that we can't pass the lazy variable `maxBatchSize` into the closure directly,
+    // because input arguments are serialized and sent to the executor side for execution.
+    val maxBatchSizePerBatch = maxBatchSize
     plan.execute().mapPartitionsInternal { iter =>
-      val context = TaskContext.get()
-      ArrowConverters.toBatchIterator(
+      KyuubiArrowConverters.toBatchIterator(
         iter,
         schemaCaptured,
         maxRecordsPerBatch,
-        timeZoneId,
-        context)
+        maxBatchSizePerBatch,
+        -1,
+        timeZoneId)
     }
   }

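A note on the `maxBatchSizePerBatch` local introduced above: `maxBatchSize` is a lazy field on the driver-side helper object, and (presumably, given the surrounding code) its evaluation depends on driver-side session state. Referencing it directly inside `mapPartitionsInternal` would defer evaluation to the executors, where that state is unavailable; copying it into a plain local `val` forces evaluation on the driver so the closure captures just a `Long`. A minimal sketch of the pattern, with illustrative names (`LazyCaptureSketch`, `maxBatchSize`) that are not Kyuubi's actual code:

import org.apache.spark.sql.SparkSession

object LazyCaptureSketch {
  // Illustrative stand-in for a driver-side lazy config value. If a closure
  // referenced this field directly, executors would re-evaluate it there,
  // where driver-side state (e.g. SparkSession.active) is not available.
  lazy val maxBatchSize: Long = 4L * 1024 * 1024

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    // Force evaluation on the driver; the closure below captures only this Long.
    val maxBatchSizeLocal = maxBatchSize
    val capped = spark.sparkContext
      .parallelize(Seq(1L, 10L, 10000000L), 2)
      .map(n => math.min(n, maxBatchSizeLocal))
    capped.collect().foreach(println)
    spark.stop()
  }
}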
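On the new call site itself: compared with Spark's `ArrowConverters.toBatchIterator`, the Kyuubi variant additionally takes a per-batch size cap (`maxBatchSizePerBatch`) alongside the record cap, plus what looks from the call site like a limit argument where `-1` means unlimited (an inference, not confirmed here). The general technique of closing a batch on whichever cap trips first can be sketched as below; this illustrates the idea, not Kyuubi's implementation:

object SizeCappedBatching {
  // Group items into batches, closing a batch when either the record-count
  // cap or the byte-size cap would be exceeded. Illustrative only.
  def batched(
      items: Iterator[Array[Byte]],
      maxRecords: Int,
      maxBytes: Long): Iterator[Seq[Array[Byte]]] = new Iterator[Seq[Array[Byte]]] {
    private val in = items.buffered
    def hasNext: Boolean = in.hasNext
    def next(): Seq[Array[Byte]] = {
      val batch = scala.collection.mutable.ArrayBuffer.empty[Array[Byte]]
      var bytes = 0L
      // Always admit the first record so an oversized record still makes progress.
      while (in.hasNext && (batch.isEmpty ||
          (batch.size < maxRecords && bytes + in.head.length <= maxBytes))) {
        val rec = in.next()
        bytes += rec.length
        batch += rec
      }
      batch.toSeq
    }
  }

  def main(args: Array[String]): Unit = {
    val records = Iterator.fill(10)(new Array[Byte](300))
    // With a 1 KiB size cap, each batch holds at most 3 of the 300-byte records.
    batched(records, maxRecords = 100, maxBytes = 1024)
      .foreach(b => println(s"batch: ${b.size} records"))
  }
}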