diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 33aa21228296d..02815b6978c47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import java.io.{FileNotFoundException, IOException}
+import java.io.{Closeable, FileNotFoundException, IOException}
 
 import org.apache.parquet.io.ParquetDecodingException
 
@@ -85,6 +85,17 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
+      private def resetCurrentIterator(): Unit = {
+        currentIterator match {
+          case iter: NextIterator[_] =>
+            iter.closeIfNeeded()
+          case iter: Closeable =>
+            iter.close()
+          case _ => // do nothing
+        }
+        currentIterator = null
+      }
+
       def hasNext: Boolean = {
         // Kill the task in case it has been marked as killed. This logic is from
         // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
@@ -128,15 +139,21 @@ class FileScanRDD(
           // Sets InputFileBlockHolder for the file block's information
           InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length)
 
+          resetCurrentIterator()
           if (ignoreMissingFiles || ignoreCorruptFiles) {
             currentIterator = new NextIterator[Object] {
               // The readFunction may read some bytes before consuming the iterator, e.g.,
-              // vectorized Parquet reader. Here we use lazy val to delay the creation of
-              // iterator so that we will throw exception in `getNext`.
-              private lazy val internalIter = readCurrentFile()
+              // vectorized Parquet reader. Here we use a lazily initialized variable to delay the
+              // creation of iterator so that we will throw exception in `getNext`.
+              private var internalIter: Iterator[InternalRow] = null
 
               override def getNext(): AnyRef = {
                 try {
+                  // Initialize `internalIter` lazily.
+                  if (internalIter == null) {
+                    internalIter = readCurrentFile()
+                  }
+
                   if (internalIter.hasNext) {
                     internalIter.next()
                   } else {
@@ -158,7 +175,13 @@ class FileScanRDD(
                 }
               }
 
-              override def close(): Unit = {}
+              override def close(): Unit = {
+                internalIter match {
+                  case iter: Closeable =>
+                    iter.close()
+                  case _ => // do nothing
+                }
+              }
             }
           } else {
             currentIterator = readCurrentFile()
@@ -188,6 +211,7 @@ class FileScanRDD(
       override def close(): Unit = {
         incTaskInputMetricsBytesRead()
         InputFileBlockHolder.unset()
+        resetCurrentIterator()
       }
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala
index d8e30e600098d..563337c58474e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala
@@ -56,6 +56,13 @@ class RecordReaderIterator[T](
     rowReader.getCurrentValue
   }
 
+  override def map[B](f: (T) => B): Iterator[B] with Closeable =
+    new Iterator[B] with Closeable {
+      override def hasNext: Boolean = RecordReaderIterator.this.hasNext
+      override def next(): B = f(RecordReaderIterator.this.next())
+      override def close(): Unit = RecordReaderIterator.this.close()
+    }
+
   override def close(): Unit = {
     if (rowReader != null) {
       try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 586952aafbbfe..b400e3688f4b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -327,18 +327,31 @@ class ParquetFileFormat
           int96RebaseMode.toString,
           enableOffHeapColumnVector && taskContext.isDefined,
           capacity)
+        // SPARK-37089: We cannot register a task completion listener to close this iterator here
+        // because downstream exec nodes have already registered their listeners. Since listeners
+        // are executed in reverse order of registration, a listener registered here would close
+        // the iterator while downstream exec nodes are still running. When off-heap column
+        // vectors are enabled, this can cause a use-after-free bug leading to a segfault.
+        //
+        // Instead, we use FileScanRDD's task completion listener to close this iterator.
         val iter = new RecordReaderIterator(vectorizedReader)
-        // SPARK-23457 Register a task completion listener before `initialization`.
-        taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
-        vectorizedReader.initialize(split, hadoopAttemptContext)
-        logDebug(s"Appending $partitionSchema ${file.partitionValues}")
-        vectorizedReader.initBatch(partitionSchema, file.partitionValues)
-        if (returningBatch) {
-          vectorizedReader.enableReturningBatches()
-        }
+        try {
+          vectorizedReader.initialize(split, hadoopAttemptContext)
+          logDebug(s"Appending $partitionSchema ${file.partitionValues}")
+          vectorizedReader.initBatch(partitionSchema, file.partitionValues)
+          if (returningBatch) {
+            vectorizedReader.enableReturningBatches()
+          }
 
-        // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
-        iter.asInstanceOf[Iterator[InternalRow]]
+          // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
+          iter.asInstanceOf[Iterator[InternalRow]]
+        } catch {
+          case e: Throwable =>
+            // SPARK-23457: In case there is an exception in initialization, close the iterator to
+            // avoid leaking resources.
+            iter.close()
+            throw e
+        }
       } else {
         logDebug(s"Falling back to parquet-mr")
         // ParquetRecordReader returns InternalRow
@@ -354,19 +367,25 @@ class ParquetFileFormat
           new ParquetRecordReader[InternalRow](readSupport)
         }
         val iter = new RecordReaderIterator[InternalRow](reader)
-        // SPARK-23457 Register a task completion listener before `initialization`.
-        taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
-        reader.initialize(split, hadoopAttemptContext)
-
-        val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
-        val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
-
-        if (partitionSchema.length == 0) {
-          // There is no partition columns
-          iter.map(unsafeProjection)
-        } else {
-          val joinedRow = new JoinedRow()
-          iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues)))
+        try {
+          reader.initialize(split, hadoopAttemptContext)
+
+          val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+          val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+
+          if (partitionSchema.length == 0) {
+            // There is no partition columns
+            iter.map(unsafeProjection)
+          } else {
+            val joinedRow = new JoinedRow()
+            iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues)))
+          }
+        } catch {
+          case e: Throwable =>
+            // SPARK-23457: In case there is an exception in initialization, close the iterator to
+            // avoid leaking resources.
+            iter.close()
+            throw e
         }
       }
     }
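The RecordReaderIterator.map override above exists so that the projected iterator handed back to FileScanRDD is still Closeable; otherwise resetCurrentIterator() could no longer reach the underlying Parquet reader once rows pass through unsafeProjection. A minimal, self-contained sketch of that close-through-map idea follows; the names CloseableIterator and CloseThroughMapDemo are illustrative stand-ins, not part of the patch or of Spark:

import java.io.Closeable

// A Closeable iterator whose map() result is still Closeable, so a later pattern match
// can reach the underlying close() even after a projection has been applied.
class CloseableIterator[A](underlying: Iterator[A], onClose: () => Unit)
  extends Iterator[A] with Closeable {

  override def hasNext: Boolean = underlying.hasNext
  override def next(): A = underlying.next()
  override def close(): Unit = onClose()

  // Analogous to the RecordReaderIterator.map override in the patch: the mapped view
  // delegates close() back to this iterator instead of dropping the Closeable capability.
  override def map[B](f: A => B): Iterator[B] with Closeable =
    new Iterator[B] with Closeable {
      override def hasNext: Boolean = CloseableIterator.this.hasNext
      override def next(): B = f(CloseableIterator.this.next())
      override def close(): Unit = CloseableIterator.this.close()
    }
}

object CloseThroughMapDemo {
  def main(args: Array[String]): Unit = {
    var closed = false
    val base = new CloseableIterator(Iterator(1, 2, 3), () => closed = true)

    // Declared as a plain Iterator, like FileScanRDD.currentIterator.
    val projected: Iterator[Int] = base.map(_ * 2)

    println(projected.toList) // List(2, 4, 6)

    // Same dispatch that FileScanRDD.resetCurrentIterator() performs on currentIterator.
    projected match {
      case c: Closeable => c.close()
      case _ => // do nothing
    }
    println(closed) // true
  }
}

Without such an override, map would return a plain Iterator, the `case iter: Closeable` branch in resetCurrentIterator() would silently skip it, and the underlying reader would stay open until the end of the task.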