From a0a005b7451a166d2b251e50132dac3c1bdfa596 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 11 Jul 2018 09:46:42 -0700 Subject: [PATCH 1/2] SPARK-23325: Use InternalRow when reading with DataSourceV2. This updates the DataSourceV2 API to use InternalRow instead of Row for the default case with no scan mix-ins. Support for readers that produce Row is added through SupportsDeprecatedScanRow, which matches the previous API. Readers that used Row now implement this class and should be migrated to InternalRow. Readers that previously implemented SupportsScanUnsafeRow have been migrated to use no SupportsScan mix-ins and produce InternalRow. --- .../sql/kafka010/KafkaContinuousReader.scala | 16 +++++----- .../sql/kafka010/KafkaMicroBatchReader.scala | 21 ++++++------- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sources/v2/reader/DataSourceReader.java | 6 ++-- .../v2/reader/InputPartitionReader.java | 7 +++-- ...ow.java => SupportsDeprecatedScanRow.java} | 25 ++++++---------- .../v2/reader/SupportsScanColumnarBatch.java | 4 +-- .../datasources/v2/DataSourceRDD.scala | 1 - .../datasources/v2/DataSourceV2ScanExec.scala | 17 ++++++----- .../continuous/ContinuousDataSourceRDD.scala | 26 ++++++++-------- .../ContinuousQueuedDataReader.scala | 8 ++--- .../ContinuousRateStreamSource.scala | 4 +-- .../sql/execution/streaming/memory.scala | 16 +++++----- .../sources/ContinuousMemoryStream.scala | 7 +++-- .../sources/RateStreamMicroBatchReader.scala | 4 +-- .../execution/streaming/sources/socket.scala | 7 +++-- .../sources/v2/JavaAdvancedDataSourceV2.java | 4 +-- .../v2/JavaPartitionAwareDataSource.java | 4 +-- .../v2/JavaSchemaRequiredDataSource.java | 5 ++-- .../sources/v2/JavaSimpleDataSourceV2.java | 5 ++-- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 9 +++--- .../sources/RateStreamProviderSuite.scala | 6 ++-- .../sql/sources/v2/DataSourceV2Suite.scala | 30 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 7 +++-- .../sql/streaming/StreamingQuerySuite.scala | 6 ++-- .../ContinuousQueuedDataReaderSuite.scala | 6 ++-- .../sources/StreamingDataSourceV2Suite.scala | 7 +++-- 27 files changed, 133 insertions(+), 127 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{SupportsScanUnsafeRow.java => SupportsDeprecatedScanRow.java} (62%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index badaa69cc303..48b91dfe764e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -26,6 +26,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ @@ -53,7 +54,7 @@ class KafkaContinuousReader( metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { + extends ContinuousReader with Logging { private lazy val session = 
SparkSession.getActiveSession.get private lazy val sc = session.sparkContext @@ -86,7 +87,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -107,8 +108,8 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[InputPartition[UnsafeRow]] + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss + ): InputPartition[InternalRow] }.asJava } @@ -161,9 +162,10 @@ case class KafkaContinuousInputPartition( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] { - override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = { + override def createContinuousReader( + offset: PartitionOffset): InputPartitionReader[InternalRow] = { val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] require(kafkaOffset.topicPartition == topicPartition, s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") @@ -192,7 +194,7 @@ class KafkaContinuousInputPartitionReader( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 737da2e51b12..6c95b2b2560c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -29,11 +29,12 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -61,7 +62,7 @@ private[kafka010] class KafkaMicroBatchReader( metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends 
MicroBatchReader with SupportsScanUnsafeRow with Logging { + extends MicroBatchReader with Logging { private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -101,7 +102,7 @@ private[kafka010] class KafkaMicroBatchReader( } } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -142,11 +143,11 @@ private[kafka010] class KafkaMicroBatchReader( val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size // Generate factories based on the offset ranges - val factories = offsetRanges.map { range => + offsetRanges.map { range => new KafkaMicroBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) - } - factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer + ): InputPartition[InternalRow] + }.asJava } override def getStartOffset: Offset = { @@ -305,11 +306,11 @@ private[kafka010] case class KafkaMicroBatchInputPartition( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] { + reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] { override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = + override def createPartitionReader(): InputPartitionReader[InternalRow] = new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } @@ -320,7 +321,7 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index c6412eac97db..5d5e57323cff 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) - val factories = reader.planUnsafeInputPartitions().asScala + val factories = reader.planInputPartitions().asScala .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java 
b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index 36a3e542b5a1..ad9c838992fa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -20,7 +20,7 @@ import java.util.List; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; @@ -43,7 +43,7 @@ * Names of these interfaces start with `SupportsScan`. Note that a reader should only * implement at most one of the special scans, if more than one special scans are implemented, * only one of them would be respected, according to the priority list from high to low: - * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. + * {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}. * * If an exception was throw when applying any of these query optimizations, the action will fail * and no Spark job will be submitted. @@ -76,5 +76,5 @@ public interface DataSourceReader { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ - List> planInputPartitions(); + List> planInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 33fa7be4c1b2..7cf382e52f67 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -26,9 +26,10 @@ * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is * responsible for outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input - * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input - * partition readers that mix in {@link SupportsScanUnsafeRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow} + * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data + * source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row} + * for data source readers that mix in {@link SupportsDeprecatedScanRow}. 
*/ @InterfaceStability.Evolving public interface InputPartitionReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java similarity index 62% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java index f2220f6d3109..595943cf4d8a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java @@ -17,30 +17,23 @@ package org.apache.spark.sql.sources.v2.reader; -import java.util.List; - import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.InternalRow; + +import java.util.List; /** * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. - * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get - * changed in the future Spark versions. + * interface to output {@link Row} instead of {@link InternalRow}. + * This is an experimental and unstable interface. */ @InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceReader { - - @Override - default List> planInputPartitions() { +public interface SupportsDeprecatedScanRow extends DataSourceReader { + default List> planInputPartitions() { throw new IllegalStateException( - "planInputPartitions not supported by default within SupportsScanUnsafeRow"); + "planInputPartitions not supported by default within SupportsDeprecatedScanRow"); } - /** - * Similar to {@link DataSourceReader#planInputPartitions()}, - * but returns data in unsafe row format. 
- */ - List> planUnsafeInputPartitions(); + List> planRowInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 0faf81db2460..f4da686740d1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -20,7 +20,7 @@ import java.util.List; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; /** @@ -30,7 +30,7 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceReader { @Override - default List> planInputPartitions() { + default List> planInputPartitions() { throw new IllegalStateException( "planInputPartitions not supported by default within SupportsScanColumnarBatch."); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 8d6fb3820d42..7ea53424ae10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c6a7684bf6ab..b030b9a929b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -75,12 +75,13 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala - case _ => - reader.planInputPartitions().asScala.map { - new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow] + private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match { + case r: SupportsDeprecatedScanRow => + r.planRowInputPartitions().asScala.map { + new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow] } + case _ => + reader.planInputPartitions().asScala } private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { @@ -132,11 +133,11 @@ case class DataSourceV2ScanExec( } class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) - extends InputPartition[UnsafeRow] { + extends InputPartition[InternalRow] { override def preferredLocations: Array[String] = partition.preferredLocations - override def createPartitionReader: InputPartitionReader[UnsafeRow] = { + override def createPartitionReader: InputPartitionReader[InternalRow] = { new RowToUnsafeInputPartitionReader( partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) } @@ -146,7 +147,7 @@ 
class RowToUnsafeInputPartitionReader( val rowReader: InputPartitionReader[Row], encoder: ExpressionEncoder[Row]) - extends InputPartitionReader[UnsafeRow] { + extends InputPartitionReader[InternalRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 73868d5967e9..1ffa1d02f143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset} -import org.apache.spark.util.{NextIterator, ThreadUtils} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader +import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition[UnsafeRow]) + val inputPartition: InputPartition[InternalRow]) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -51,8 +51,8 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + private val readerInputPartitions: Seq[InputPartition[InternalRow]]) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { readerInputPartitions.zipWithIndex.map { @@ -64,7 +64,7 @@ class ContinuousDataSourceRDD( * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. */ - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { // If attempt number isn't 0, this is a task retry, which we don't support. 
if (context.attemptNumber() != 0) { throw new ContinuousTaskRetryException() @@ -80,8 +80,8 @@ class ContinuousDataSourceRDD( partition.queueReader } - new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = { + new NextIterator[InternalRow] { + override def getNext(): InternalRow = { readerForPartition.next() match { case null => finished = true @@ -101,9 +101,9 @@ class ContinuousDataSourceRDD( object ContinuousDataSourceRDD { private[continuous] def getContinuousReader( - reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = { + reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { reader match { - case r: ContinuousInputPartitionReader[UnsafeRow] => r + case r: ContinuousInputPartitionReader[InternalRow] => r case wrapped: RowToUnsafeInputPartitionReader => wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 8c74b8244d09..bfb87053db47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset import org.apache.spark.util.ThreadUtils @@ -52,7 +52,7 @@ class ContinuousQueuedDataReader( */ sealed trait ContinuousRecord case object EpochMarker extends ContinuousRecord - case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + case class ContinuousRow(row: InternalRow, offset: PartitionOffset) extends ContinuousRecord private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) @@ -79,12 +79,12 @@ class ContinuousQueuedDataReader( } /** - * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * Return the next row to be read in the current epoch, or null if the epoch is done. * * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch * will call next() again to start getting rows. 
*/ - def next(): UnsafeRow = { + def next(): InternalRow = { val POLL_TIMEOUT_MS = 1000 var currentEntry: ContinuousRecord = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 516a563bdcc7..55ce3ae38ee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -35,7 +35,7 @@ case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset class RateStreamContinuousReader(options: DataSourceOptions) - extends ContinuousReader { + extends ContinuousReader with SupportsDeprecatedScanRow { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def planInputPartitions(): java.util.List[InputPartition[Row]] = { + override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5..f81abdcc3711 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -28,12 +28,13 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -79,8 +80,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. 
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) - with MicroBatchReader with SupportsScanUnsafeRow with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -139,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (endOffset.offset == -1) null else endOffset } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] + new MemoryStreamInputPartition(block): InputPartition[InternalRow] }.asJava } } @@ -202,9 +202,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) class MemoryStreamInputPartition(records: Array[UnsafeRow]) - extends InputPartition[UnsafeRow] { - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { - new InputPartitionReader[UnsafeRow] { + extends InputPartition[InternalRow] { + override def createPartitionReader(): InputPartitionReader[InternalRow] = { + new InputPartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 0bf90b806332..e776ebc08e30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -49,7 +49,8 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. 
*/ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport + with SupportsDeprecatedScanRow { private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -99,7 +100,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def planInputPartitions(): ju.List[InputPartition[Row]] = { + override def planRowInputPartitions(): ju.List[InputPartition[Row]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index b393c48baee8..7a3452aa315c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { + extends MicroBatchReader with SupportsDeprecatedScanRow with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def planInputPartitions(): java.util.List[InputPartition[Row]] = { + override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 91e3b7179c34..e3a2c007a9ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -50,7 +50,8 @@ object TextSocketMicroBatchReader { * debugging. This MicroBatchReader will *not* work in production applications due to multiple * reasons, including no support for fault recovery. 
*/ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader + with SupportsDeprecatedScanRow with Logging { private var startOffset: Offset = _ private var endOffset: Offset = _ @@ -141,7 +142,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 445cb29f5ee3..c130b5f1e251 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -33,7 +33,7 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters { + SupportsPushDownFilters, SupportsDeprecatedScanRow { // Exposed for testing. public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); @@ -79,7 +79,7 @@ public Filter[] pushedFilters() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { List> res = new ArrayList<>(); Integer lowerBound = null; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index e49c8cf8b9e1..35aafb532d80 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -34,7 +34,7 @@ public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsReportPartitioning { + class Reader implements DataSourceReader, SupportsReportPartitioning, SupportsDeprecatedScanRow { private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override @@ -43,7 +43,7 @@ public StructType readSchema() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { return java.util.Arrays.asList( new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 80eeffd95f83..6dee94c34e21 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -25,11 +25,12 @@ import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow; import 
org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { - class Reader implements DataSourceReader { + class Reader implements DataSourceReader, SupportsDeprecatedScanRow { private final StructType schema; Reader(StructType schema) { @@ -42,7 +43,7 @@ public StructType readSchema() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 8522a63898a3..5c2f351975c7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -28,11 +28,12 @@ import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow; import org.apache.spark.sql.types.StructType; public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader { + class Reader implements DataSourceReader, SupportsDeprecatedScanRow { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -41,7 +42,7 @@ public StructType readSchema() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { return java.util.Arrays.asList( new JavaSimpleInputPartition(0, 5), new JavaSimpleInputPartition(5, 10)); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index 3ad8e7a0104c..25b89c7fd36a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; @@ -29,7 +30,7 @@ public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsScanUnsafeRow { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -38,7 +39,7 @@ public StructType readSchema() { } @Override - public List> planUnsafeInputPartitions() { + public List> planInputPartitions() { return java.util.Arrays.asList( new JavaUnsafeRowInputPartition(0, 5), new JavaUnsafeRowInputPartition(5, 10)); @@ -46,7 +47,7 @@ public List> planUnsafeInputPartitions() { } static class JavaUnsafeRowInputPartition - implements InputPartition, InputPartitionReader { + implements InputPartition, InputPartitionReader { private int start; private int end; private UnsafeRow row; @@ -59,7 +60,7 @@ static class JavaUnsafeRowInputPartition } @Override - public InputPartitionReader createPartitionReader() { + public InputPartitionReader createPartitionReader() { return new 
JavaUnsafeRowInputPartition(start - 1, end); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 9115a384d079..260a0376daeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -146,7 +146,7 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val tasks = reader.planRowInputPartitions() assert(tasks.size == 1) val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[Row]() @@ -165,7 +165,7 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val tasks = reader.planRowInputPartitions() assert(tasks.size == 11) val readData = tasks.asScala @@ -311,7 +311,7 @@ class RateSourceSuite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val tasks = reader.planRowInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e96cd4500458..d73eebbc84b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -23,6 +23,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} @@ -344,10 +345,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { + class Reader extends DataSourceReader with SupportsDeprecatedScanRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } @@ -357,10 +358,10 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { + class Reader extends DataSourceReader with SupportsDeprecatedScanRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def 
planRowInputPartitions(): JList[InputPartition[Row]] = { java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) } } @@ -390,7 +391,7 @@ class SimpleInputPartition(start: Int, end: Int) class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader + class Reader extends DataSourceReader with SupportsDeprecatedScanRow with SupportsPushDownRequiredColumns with SupportsPushDownFilters { var requiredSchema = new StructType().add("i", "int").add("j", "int") @@ -415,7 +416,7 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption @@ -467,10 +468,10 @@ class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsScanUnsafeRow { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), new UnsafeRowInputPartitionReader(5, 10)) } @@ -480,14 +481,14 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { } class UnsafeRowInputPartitionReader(start: Int, end: Int) - extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] { + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this + override def createPartitionReader(): InputPartitionReader[InternalRow] = this override def next(): Boolean = { current += 1 @@ -504,8 +505,8 @@ class UnsafeRowInputPartitionReader(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - class Reader(val readSchema: StructType) extends DataSourceReader { - override def planInputPartitions(): JList[InputPartition[Row]] = + class Reader(val readSchema: StructType) extends DataSourceReader with SupportsDeprecatedScanRow { + override def planRowInputPartitions(): JList[InputPartition[Row]] = java.util.Collections.emptyList() } @@ -568,10 +569,11 @@ class BatchInputPartitionReader(start: Int, end: Int) class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsReportPartitioning { + class Reader extends DataSourceReader with SupportsReportPartitioning + with SupportsDeprecatedScanRow { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { // Note that we don't have same value of column `a` across partitions. 
java.util.Arrays.asList( new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 1334cf71ae98..98d7eedbcb9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -42,10 +42,11 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader { + class Reader(path: String, conf: Configuration) extends DataSourceReader + with SupportsDeprecatedScanRow { override def readSchema(): StructType = schema - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 936a076d647b..78199b0a1c19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -30,7 +30,7 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ @@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { clock.waitTillTime(1350) - super.planUnsafeInputPartitions() + super.planInputPartitions() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index 0e7e6febb53d..4f198819b58d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -19,12 +19,12 @@ 
package org.apache.spark.sql.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} -import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.InputPartition @@ -73,8 +73,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new InputPartition[UnsafeRow] { - override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] { + val factory = new InputPartition[InternalRow] { + override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] { var index = -1 var curr: UnsafeRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index c1a28b9bc75e..7c012158bd75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,14 +26,15 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReader() extends MicroBatchReader with ContinuousReader { +case class FakeReader() extends MicroBatchReader with ContinuousReader + with SupportsDeprecatedScanRow { def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} def getStartOffset: Offset = RateStreamOffset(Map()) def getEndOffset: Offset = RateStreamOffset(Map()) @@ -44,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { + def planRowInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From d1fa32e201e73f281a87d46a3510f0e3082c1d35 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 11 Jul 2018 09:53:21 -0700 Subject: [PATCH 2/2] SPARK-23325: Add physical projection in DataSourceV2Strategy. 
These projections ensure that rows are converted to UnsafeRow before they are passed to physical operators that require UnsafeRow. These operators are rare and both project and filter operators support InternalRow. When push-down is handled during conversion to a physical plan, filters should be placed below a final projection. --- .../datasources/v2/DataSourceV2Strategy.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 2a7f1de2c7c1..9414e68155b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -125,16 +125,13 @@ object DataSourceV2Strategy extends Strategy { val filterCondition = postScanFilters.reduceLeftOption(And) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - val withProjection = if (withFilter.output != project) { - ProjectExec(project, withFilter) - } else { - withFilter - } - - withProjection :: Nil + // always add the projection, which will produce unsafe rows required by some operators + ProjectExec(project, withFilter) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + // ensure there is a projection, which will produce unsafe rows required by some operators + ProjectExec(r.output, + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
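
For reference only (not part of either patch above), a minimal reader written against the updated API might look like the sketch below. Class and column names are illustrative; the shape follows SimpleDataSourceV2 from the test suite, but it produces InternalRow directly with no SupportsScan mix-in, relying on the projection added in the second commit to convert rows to UnsafeRow where downstream operators need it.

// Illustrative sketch, assuming the post-patch DataSourceV2 API shown in this change.
import java.util.{Arrays => JArrays, List => JList}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.StructType

class ExampleInternalRowDataSource extends DataSourceV2 with ReadSupport {
  class Reader extends DataSourceReader {
    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

    // No SupportsScan mix-in: the default scan now plans InputPartition[InternalRow].
    override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
      JArrays.asList(new ExampleInputPartition(0, 5), new ExampleInputPartition(5, 10))
    }
  }

  override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}

class ExampleInputPartition(start: Int, end: Int)
    extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] {
  private var current = start - 1

  // Create a fresh reader so the partition can be sent to an executor and read there.
  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new ExampleInputPartition(start, end)

  override def next(): Boolean = {
    current += 1
    current < end
  }

  // Rows are returned as InternalRow; the ProjectExec added by DataSourceV2Strategy in the
  // second commit produces UnsafeRow for operators that require it.
  override def get(): InternalRow = new GenericInternalRow(Array[Any](current, -current))

  override def close(): Unit = {}
}

A reader that still wants to produce org.apache.spark.sql.Row would instead mix in SupportsDeprecatedScanRow and implement planRowInputPartitions(), as the migrated streaming sources in the first patch do.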