diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
index a620f5229237c..1c09cc9f7ff26 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
@@ -32,12 +32,11 @@ import org.apache.orc.impl.OrcTail;
 import org.apache.orc.mapred.OrcInputFormat;
 
+import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns;
 import org.apache.spark.sql.execution.datasources.orc.OrcShimUtils.VectorizedRowBatchWrap;
-import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
-import org.apache.spark.sql.execution.vectorized.ConstantColumnVector;
-import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.*;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
@@ -74,11 +73,14 @@ public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
   @VisibleForTesting
   public ColumnarBatch columnarBatch;
 
+  private final MemoryMode memoryMode;
+
   // The wrapped ORC column vectors.
   private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers;
 
-  public OrcColumnarBatchReader(int capacity) {
+  public OrcColumnarBatchReader(int capacity, MemoryMode memoryMode) {
     this.capacity = capacity;
+    this.memoryMode = memoryMode;
   }
 
@@ -186,7 +188,12 @@ public void initBatch(
         int colId = requestedDataColIds[i];
         // Initialize the missing columns once.
         if (colId == -1) {
-          OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
+          final WritableColumnVector missingCol;
+          if (memoryMode == MemoryMode.OFF_HEAP) {
+            missingCol = new OffHeapColumnVector(capacity, dt);
+          } else {
+            missingCol = new OnHeapColumnVector(capacity, dt);
+          }
           // Check if the missing column has an associated default value in the schema metadata.
           // If so, fill the corresponding column vector with the value.
           Object defaultValue = ResolveDefaultColumns.existenceDefaultValues(requiredSchema)[i];
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index e4ae47af79047..5513359fdaa31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -31,6 +31,7 @@ import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce._
 
 import org.apache.spark.TaskContext
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -152,6 +153,12 @@ class OrcFileFormat
       assert(supportBatch(sparkSession, resultSchema))
     }
 
+    val memoryMode = if (sqlConf.offHeapColumnVectorEnabled) {
+      MemoryMode.OFF_HEAP
+    } else {
+      MemoryMode.ON_HEAP
+    }
+
     OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis)
 
     val broadcastedConf =
@@ -196,7 +203,7 @@ class OrcFileFormat
       val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
       if (enableVectorizedReader) {
-        val batchReader = new OrcColumnarBatchReader(capacity)
+        val batchReader = new OrcColumnarBatchReader(capacity, memoryMode)
         // SPARK-23399 Register a task completion listener first to call `close()` in all cases.
         // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM)
         // after opening a file.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 9543c33b5721a..c44a5d30cafe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -26,6 +26,7 @@ import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce.OrcInputFormat
 
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
@@ -58,7 +59,8 @@ case class OrcPartitionReaderFactory(
     partitionSchema: StructType,
     filters: Array[Filter],
     aggregation: Option[Aggregation],
-    options: OrcOptions) extends FilePartitionReaderFactory {
+    options: OrcOptions,
+    memoryMode: MemoryMode) extends FilePartitionReaderFactory {
   private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val capacity = sqlConf.orcVectorizedReaderBatchSize
@@ -147,7 +149,7 @@ case class OrcPartitionReaderFactory(
     val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
     val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
-    val batchReader = new OrcColumnarBatchReader(capacity)
+    val batchReader = new OrcColumnarBatchReader(capacity, memoryMode)
     batchReader.initialize(fileSplit, taskAttemptContext, readerOptions.getOrcTail)
     val requestedPartitionColIds =
       Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 894f7e765a4f4..6bf36c72e0352 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsScala
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
@@ -65,11 +66,16 @@ case class OrcScan(
   override def createReaderFactory(): PartitionReaderFactory = {
     val broadcastedConf = sparkSession.sparkContext.broadcast(
       new SerializableConfiguration(hadoopConf))
+    val memoryMode = if (sparkSession.sessionState.conf.offHeapColumnVectorEnabled) {
+      MemoryMode.OFF_HEAP
+    } else {
+      MemoryMode.ON_HEAP
+    }
     // The partition values are already truncated in `FileScan.partitions`.
     // We should use `readPartitionSchema` as the partition schema here.
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema,
       readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate,
-      new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf))
+      new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf), memoryMode)
   }
 
   override def equals(obj: Any): Boolean = obj match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
index a9389c1c21b40..06ea12f83ce75 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
@@ -26,11 +26,12 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.orc.TypeDescription
 
 import org.apache.spark.TestUtils
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
+import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, OffHeapColumnVector}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -53,7 +54,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
       requestedDataColIds: Array[Int],
       requestedPartitionColIds: Array[Int],
       resultFields: Array[StructField]): OrcColumnarBatchReader = {
-    val reader = new OrcColumnarBatchReader(4096)
+    val reader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
     reader.initBatch(
       orcFileSchema,
       resultFields,
@@ -117,7 +118,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
       val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty)
       val taskConf = sqlContext.sessionState.newHadoopConf()
      val orcFileSchema = TypeDescription.fromString(schema.simpleString)
-      val vectorizedReader = new OrcColumnarBatchReader(4096)
+      val vectorizedReader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
       val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
       val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
@@ -148,4 +149,15 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-46598: off-heap mode") {
+    val reader = new OrcColumnarBatchReader(4096, MemoryMode.OFF_HEAP)
+    reader.initBatch(
+      TypeDescription.fromString("struct<col1:int,col2:int>"),
+      StructType.fromDDL("col1 int, col2 int, col3 int").fields,
+      Array(0, 1, -1),
+      Array(-1, -1, -1),
+      InternalRow.empty)
+    assert(reader.columnarBatch.column(2).isInstanceOf[OffHeapColumnVector])
+  }
 }
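
Note (not part of the patch): a minimal end-to-end sketch of how the new memory mode is picked up when reading ORC. It relies on the existing `spark.sql.columnVector.offheap.enabled` and `spark.sql.orc.enableVectorizedReader` configs; the output path, the demo object name, and the extra column name are illustrative assumptions, not anything from this change.

import org.apache.spark.sql.SparkSession

object OrcOffHeapDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // With this flag enabled, OrcFileFormat / OrcScan now pass MemoryMode.OFF_HEAP
      // to OrcColumnarBatchReader instead of always defaulting to on-heap vectors.
      .config("spark.sql.columnVector.offheap.enabled", "true")
      .config("spark.sql.orc.enableVectorizedReader", "true")
      .getOrCreate()

    // Illustrative path: write a one-column ORC file, then read it back with a
    // wider schema so "extra" is a missing column, exercising the new branch that
    // allocates an OffHeapColumnVector for missing columns.
    val path = "/tmp/orc_offheap_demo"
    spark.range(10).toDF("id").write.mode("overwrite").orc(path)
    spark.read.schema("id bigint, extra int").orc(path).show()

    spark.stop()
  }
}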