File: OrcColumnarBatchReader.java
@@ -32,12 +32,11 @@
import org.apache.orc.impl.OrcTail;
import org.apache.orc.mapred.OrcInputFormat;

+ import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns;
import org.apache.spark.sql.execution.datasources.orc.OrcShimUtils.VectorizedRowBatchWrap;
- import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
- import org.apache.spark.sql.execution.vectorized.ConstantColumnVector;
- import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+ import org.apache.spark.sql.execution.vectorized.*;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.vectorized.ColumnarBatch;

@@ -74,11 +73,14 @@ public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
@VisibleForTesting
public ColumnarBatch columnarBatch;

+ private final MemoryMode memoryMode;

// The wrapped ORC column vectors.
private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers;

- public OrcColumnarBatchReader(int capacity) {
+ public OrcColumnarBatchReader(int capacity, MemoryMode memoryMode) {
this.capacity = capacity;
+ this.memoryMode = memoryMode;
}


@@ -186,7 +188,12 @@ public void initBatch(
int colId = requestedDataColIds[i];
// Initialize the missing columns once.
if (colId == -1) {
- OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
+ final WritableColumnVector missingCol;

[Review comment from @LuciferYang (Contributor), Jan 4, 2024, on the line above]
    Is it possible to use ConstantColumnVector for the missingCol? This may be another story.
[Reply from the PR author]
    It seems simpler to use ConstantColumnVector here. I've updated the PR.
(A sketch of the ConstantColumnVector alternative follows this file's diff.)

+ if (memoryMode == MemoryMode.OFF_HEAP) {
+   missingCol = new OffHeapColumnVector(capacity, dt);
+ } else {
+   missingCol = new OnHeapColumnVector(capacity, dt);
+ }
// Check if the missing column has an associated default value in the schema metadata.
// If so, fill the corresponding column vector with the value.
Object defaultValue = ResolveDefaultColumns.existenceDefaultValues(requiredSchema)[i];
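To make the reviewer's suggestion concrete, here is a minimal sketch (written in Scala for brevity, against the same vectorized API) of how a missing column could be backed by a ConstantColumnVector instead of an OnHeap/OffHeapColumnVector. A missing column holds one value for the whole batch, so a constant vector sidesteps the on-heap/off-heap choice entirely; the capacity, data type, and default value below are illustrative stand-ins, not the PR's code.

    import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
    import org.apache.spark.sql.types.IntegerType

    // Sketch only: back a missing int column with a single constant value.
    val capacity = 4096                       // batch size, as in the reader
    val defaultValue: java.lang.Integer = 42  // stand-in for the resolved existence default; may be null

    val missingCol = new ConstantColumnVector(capacity, IntegerType)
    if (defaultValue == null) {
      missingCol.setNull()                 // a missing column with no default reads as null
    } else {
      missingCol.setInt(defaultValue)      // type-generic code could populate the vector from an InternalRow instead
    }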
File: OrcFileFormat.scala
@@ -31,6 +31,7 @@ import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce._

import org.apache.spark.TaskContext
+ import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -152,6 +153,12 @@ class OrcFileFormat
assert(supportBatch(sparkSession, resultSchema))
}

+ val memoryMode = if (sqlConf.offHeapColumnVectorEnabled) {

[Comment from the PR author on the line above]
    I moved it outside of the lambda, so that we don't hit an NPE by referencing sqlConf.
(A sketch of this closure-capture concern follows this file's diff.)

+   MemoryMode.OFF_HEAP
+ } else {
+   MemoryMode.ON_HEAP
+ }

OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis)

val broadcastedConf =
@@ -196,7 +203,7 @@
val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)

if (enableVectorizedReader) {
- val batchReader = new OrcColumnarBatchReader(capacity)
+ val batchReader = new OrcColumnarBatchReader(capacity, memoryMode)
// SPARK-23399 Register a task completion listener first to call `close()` in all cases.
// There is a possibility that `initialize` and `initBatch` hit some errors (like OOM)
// after opening a file.
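As background for the author's comment about moving the lookup out of the lambda: buildReaderWithPartitionValues returns a per-file function that is shipped to executors, where session-bound state such as sqlConf is not usable. The sketch below is a simplified, hypothetical shape of that pattern (the real method takes schemas, filters, and a Hadoop configuration), showing the memory mode being resolved eagerly on the driver so that the closure captures only the small enum value.

    import org.apache.spark.memory.MemoryMode
    import org.apache.spark.sql.execution.datasources.orc.OrcColumnarBatchReader

    // Hypothetical, simplified builder: resolve the session config on the driver.
    def buildReader(offHeapColumnVectorEnabled: Boolean, capacity: Int): String => OrcColumnarBatchReader = {
      val memoryMode =
        if (offHeapColumnVectorEnabled) MemoryMode.OFF_HEAP else MemoryMode.ON_HEAP

      // The returned closure runs on executors and captures only `capacity` and
      // `memoryMode`, both small and serializable; capturing sqlConf here instead
      // is what the author says led to an NPE.
      (file: String) => new OrcColumnarBatchReader(capacity, memoryMode)
    }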
File: OrcPartitionReaderFactory.scala
@@ -26,6 +26,7 @@ import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce.OrcInputFormat

import org.apache.spark.broadcast.Broadcast
+ import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
@@ -58,7 +59,8 @@ case class OrcPartitionReaderFactory(
partitionSchema: StructType,
filters: Array[Filter],
aggregation: Option[Aggregation],
- options: OrcOptions) extends FilePartitionReaderFactory {
+ options: OrcOptions,
+ memoryMode: MemoryMode) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize
@@ -147,7 +149,7 @@
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)

- val batchReader = new OrcColumnarBatchReader(capacity)
+ val batchReader = new OrcColumnarBatchReader(capacity, memoryMode)
batchReader.initialize(fileSplit, taskAttemptContext, readerOptions.getOrcTail)
val requestedPartitionColIds =
Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length)
File: OrcScan.scala
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsScala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

+ import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
@@ -65,11 +66,16 @@ case class OrcScan(
override def createReaderFactory(): PartitionReaderFactory = {
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
+ val memoryMode = if (sparkSession.sessionState.conf.offHeapColumnVectorEnabled) {
+   MemoryMode.OFF_HEAP
+ } else {
+   MemoryMode.ON_HEAP
+ }
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate,
- new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf))
+ new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf), memoryMode)
}

override def equals(obj: Any): Boolean = obj match {
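For completeness, a hedged usage example of how the new code path gets exercised from user code. The offHeapColumnVectorEnabled flag read above is assumed to be backed by the spark.sql.columnVector.offheap.enabled key, and the ORC path is a placeholder.

    import org.apache.spark.sql.SparkSession

    // With the flag on, both the DSv1 (OrcFileFormat) and DSv2 (OrcScan) read paths
    // now pass MemoryMode.OFF_HEAP down to OrcColumnarBatchReader.
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.columnVector.offheap.enabled", "true")
      .getOrCreate()

    spark.read.orc("/tmp/example_orc_table").show()  // placeholder path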
File: OrcColumnarBatchReaderSuite.scala
@@ -26,11 +26,12 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc.TypeDescription

import org.apache.spark.TestUtils
+ import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
- import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
+ import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, OffHeapColumnVector}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -53,7 +54,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
requestedDataColIds: Array[Int],
requestedPartitionColIds: Array[Int],
resultFields: Array[StructField]): OrcColumnarBatchReader = {
- val reader = new OrcColumnarBatchReader(4096)
+ val reader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
reader.initBatch(
orcFileSchema,
resultFields,
@@ -117,7 +118,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty)
val taskConf = sqlContext.sessionState.newHadoopConf()
val orcFileSchema = TypeDescription.fromString(schema.simpleString)
- val vectorizedReader = new OrcColumnarBatchReader(4096)
+ val vectorizedReader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)

@@ -148,4 +149,15 @@
}
}
}

test("SPARK-46598: off-heap mode") {
val reader = new OrcColumnarBatchReader(4096, MemoryMode.OFF_HEAP)
reader.initBatch(
TypeDescription.fromString("struct<col1:int,col2:int>"),
StructType.fromDDL("col1 int, col2 int, col3 int").fields,
Array(0, 1, -1),
Array(-1, -1, -1),
InternalRow.empty)
assert(reader.columnarBatch.column(2).isInstanceOf[OffHeapColumnVector])
}
}
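A hypothetical end-to-end companion test (not part of the PR) could additionally exercise the whole vectorized ORC read path with the session flag switched on; the config key is assumed to be the one backing offHeapColumnVectorEnabled, and the helpers come from the suite's SharedSparkSession/SQLTestUtils mixins.

    // Hypothetical addition inside OrcColumnarBatchReaderSuite.
    test("read ORC with off-heap column vectors enabled") {
      withSQLConf("spark.sql.columnVector.offheap.enabled" -> "true") {
        withTempPath { dir =>
          spark.range(10).write.orc(dir.getCanonicalPath)
          assert(spark.read.orc(dir.getCanonicalPath).count() === 10)
        }
      }
    }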