diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala similarity index 74% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 1753a28fba2f..8ce56a249622 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -25,15 +25,16 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.TaskContext import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType /** - * A [[ContinuousReadSupport]] for data from kafka. + * A [[ContinuousReader]] for data from kafka. * * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. @@ -46,49 +47,70 @@ import org.apache.spark.sql.types.StructType * scenarios, where some offsets after the specified initial ones can't be * properly read. */ -class KafkaContinuousReadSupport( +class KafkaContinuousReader( offsetReader: KafkaOffsetReader, kafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReadSupport with Logging { + extends ContinuousReader with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - override def initialOffset(): Offset = { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets - } + // Initialized when creating reader factories. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. 
+ @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ - override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss) + private var offset: Offset = _ + override def setStartOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } } + override def getStartOffset(): Offset = offset + override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - }.toArray - } - - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - KafkaContinuousReaderFactory + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss + ): InputPartition[InternalRow] + }.asJava } /** Stop this source and free any resources it has allocated. 
*/ @@ -105,9 +127,8 @@ class KafkaContinuousReadSupport( KafkaSourceOffset(mergedMap) } - override def needsReconfiguration(config: ScanConfig): Boolean = { - val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions - offsetReader.fetchLatestOffsets().keySet != knownPartitions + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" @@ -141,51 +162,23 @@ case class KafkaContinuousInputPartition( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends InputPartition - -object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { - override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { - val p = partition.asInstanceOf[KafkaContinuousInputPartition] - new KafkaContinuousPartitionReader( - p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss) + failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] { + + override def createContinuousReader( + offset: PartitionOffset): InputPartitionReader[InternalRow] = { + val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] + require(kafkaOffset.topicPartition == topicPartition, + s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") + new KafkaContinuousInputPartitionReader( + topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } -} - -class KafkaContinuousScanConfigBuilder( - schema: StructType, - startOffset: Offset, - offsetReader: KafkaOffsetReader, - reportDataLoss: String => Unit) - extends ScanConfigBuilder { - - override def build(): ScanConfig = { - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - KafkaContinuousScanConfig(schema, startOffsets) + override def createPartitionReader(): KafkaContinuousInputPartitionReader = { + new KafkaContinuousInputPartitionReader( + topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } } -case class KafkaContinuousScanConfig( - readSchema: StructType, - startOffsets: Map[TopicPartition, Long]) - extends ScanConfig { - - // Created when building the scan config builder. If this diverges from the partitions at the - // latest offsets, we need to reconfigure the kafka read support. - def knownPartitions: Set[TopicPartition] = startOffsets.keySet -} - /** * A per-task data reader for continuous Kafka processing. * @@ -196,12 +189,12 @@ case class KafkaContinuousScanConfig( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. 
*/ -class KafkaContinuousPartitionReader( +class KafkaContinuousInputPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala similarity index 84% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index bb4de674c3c7..8cc989fce197 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -21,6 +21,8 @@ import java.{util => ju} import java.io._ import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ + import org.apache.commons.io.IOUtils import org.apache.spark.SparkEnv @@ -29,17 +31,16 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} -import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread /** - * A [[MicroBatchReadSupport]] that reads data from Kafka. + * A [[MicroBatchReader]] that reads data from Kafka. * * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For @@ -54,13 +55,17 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. 
*/ -private[kafka010] class KafkaMicroBatchReadSupport( +private[kafka010] class KafkaMicroBatchReader( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) extends RateControlMicroBatchReadSupport with Logging { + failOnDataLoss: Boolean) + extends MicroBatchReader with Logging { + + private var startPartitionOffsets: PartitionOffsetMap = _ + private var endPartitionOffsets: PartitionOffsetMap = _ private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", @@ -70,40 +75,34 @@ private[kafka010] class KafkaMicroBatchReadSupport( Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) private val rangeCalculator = KafkaOffsetRangeCalculator(options) - - private var endPartitionOffsets: KafkaSourceOffset = _ - /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running * `KafkaConsumer.poll` may hang forever (KAFKA-1894). */ - override def initialOffset(): Offset = { - KafkaSourceOffset(getOrCreateInitialPartitionOffsets()) - } - - override def latestOffset(start: Offset): Offset = { - val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets - val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() - endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets => - rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) - }.getOrElse { - latestPartitionOffsets - }) - endPartitionOffsets - } - - override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema - - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() + + override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + startPartitionOffsets = Option(start.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse(initialPartitionOffsets) + + endPartitionOffsets = Option(end.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse { + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + } + } } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets - val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets - + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -145,19 +144,26 @@ private[kafka010] class KafkaMicroBatchReadSupport( // Generate factories based on the offset ranges offsetRanges.map { range => - KafkaMicroBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, 
failOnDataLoss, reuseKafkaConsumer) - }.toArray + new KafkaMicroBatchInputPartition( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer + ): InputPartition[InternalRow] + }.asJava + } + + override def getStartOffset: Offset = { + KafkaSourceOffset(startPartitionOffsets) } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - KafkaMicroBatchReaderFactory + override def getEndOffset: Offset = { + KafkaSourceOffset(endPartitionOffsets) } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } + override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + override def commit(end: Offset): Unit = {} override def stop(): Unit = { @@ -300,23 +306,22 @@ private[kafka010] case class KafkaMicroBatchInputPartition( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartition + reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] { -private[kafka010] object KafkaMicroBatchReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val p = partition.asInstanceOf[KafkaMicroBatchInputPartition] - KafkaMicroBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, - p.failOnDataLoss, p.reuseKafkaConsumer) - } + override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray + + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, + failOnDataLoss, reuseKafkaConsumer) } -/** A [[PartitionReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchPartitionReader( +/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. 
*/ +private[kafka010] case class KafkaMicroBatchInputPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends PartitionReader[InternalRow] with Logging { + reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 28c9853bfea9..d225c1ea6b7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,8 +30,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -45,9 +46,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamingWriteSupportProvider - with ContinuousReadSupportProvider - with MicroBatchReadSupportProvider + with StreamWriteSupport + with ContinuousReadSupport + with MicroBatchReadSupport with Logging { import KafkaSourceProvider._ @@ -107,12 +108,13 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read - * batches of Kafka data in a micro-batch streaming query. + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches + * of Kafka data in a micro-batch streaming query. */ - override def createMicroBatchReadSupport( + override def createMicroBatchReader( + schema: Optional[StructType], metadataPath: String, - options: DataSourceOptions): KafkaMicroBatchReadSupport = { + options: DataSourceOptions): KafkaMicroBatchReader = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) @@ -138,7 +140,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaMicroBatchReadSupport( + new KafkaMicroBatchReader( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), options, @@ -148,12 +150,13 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read + * Creates a [[ContinuousInputPartitionReader]] to read * Kafka data in a continuous streaming query. 
*/ - override def createContinuousReadSupport( + override def createContinuousReader( + schema: Optional[StructType], metadataPath: String, - options: DataSourceOptions): KafkaContinuousReadSupport = { + options: DataSourceOptions): KafkaContinuousReader = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned @@ -178,7 +181,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaContinuousReadSupport( + new KafkaContinuousReader( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), parameters, @@ -267,11 +270,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamingWriteSupport( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { + options: DataSourceOptions): StreamWriter = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -282,7 +285,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - new KafkaStreamingWriteSupport(topic, producerParams, schema) + new KafkaStreamWriter(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala similarity index 91% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 927c56d9ce82..97c577d5a8b9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** @@ -33,20 +33,20 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. 
*/ -class KafkaStreamingWriteSupport( +class KafkaStreamWriter( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends StreamingWriteSupport { + extends StreamWriter { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createStreamingWriterFactory(): KafkaStreamWriterFactory = + override def createWriterFactory(): KafkaStreamWriterFactory = KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -63,9 +63,9 @@ class KafkaStreamingWriteSupport( */ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends StreamingDataWriterFactory { + extends DataWriterFactory[InternalRow] { - override def createWriter( + override def createDataWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index af510219a6f6..a0e5818dbbb6 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.producer.ProducerRecord import org.apache.spark.sql.Dataset -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger @@ -207,13 +207,11 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { testUtils.createTopic(topic2, partitions = 5) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2ScanExec - if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => - scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] - }.exists { config => + query.lastExecution.logical.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r + }.exists { r => // Ensure the new topic is present and the old topic is gone. 
- config.knownPartitions.exists(_.topic == topic2) + r.knownPartitions.exists(_.topic == topic2) }, s"query never reconfigured to new topic $topic2") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa6bdc20bd4f..fa1468a3943c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -46,10 +46,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { testUtils.addPartitions(topic, newCount) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2ScanExec - if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => - scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] + query.lastExecution.logical.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 8e246dbbf5d7..65615fdb5b3e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.Locale +import java.util.{Locale, Optional} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger @@ -28,7 +28,7 @@ import scala.collection.JavaConverters._ import scala.io.Source import scala.util.Random -import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata} +import org.apache.kafka.clients.producer.{ProducerRecord, RecordMetadata} import org.apache.kafka.common.TopicPartition import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ @@ -40,9 +40,11 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest 
extends StreamTest with SharedSQLContext with KafkaTest { @@ -112,16 +114,14 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources: Seq[BaseStreamingSource] = { + val sources = { query.get.logicalPlan.collect { case StreamingExecutionRelation(source: KafkaSource, _) => source - case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case r: StreamingDataSourceV2Relation - if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => - r.readSupport.asInstanceOf[KafkaContinuousReadSupport] + case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader } }) }.distinct @@ -905,7 +905,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { makeSureGetOffsetCalled, AssertOnQuery { query => query.logicalPlan.collect { - case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true + case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true }.nonEmpty } ) @@ -930,16 +930,17 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val readSupport = provider.createMicroBatchReadSupport( - dir.getAbsolutePath, new DataSourceOptions(options.asJava)) - val config = readSupport.newScanConfigBuilder( - KafkaSourceOffset(Map(tp -> 0L)), - KafkaSourceOffset(Map(tp -> 100L))).build() - val inputPartitions = readSupport.planInputPartitions(config) + val reader = provider.createMicroBatchReader( + Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + reader.setOffsetRange( + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) + ) + val factories = reader.planInputPartitions().asScala .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) - withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") { - assert(inputPartitions.size == numPartitionsGenerated) - inputPartitions.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { + assert(factories.size == numPartitionsGenerated) + factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java new file mode 100644 index 000000000000..7df5a451ae5f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for continuous stream processing. + */ +@InterfaceStability.Evolving +public interface ContinuousReadSupport extends DataSourceV2 { + /** + * Creates a {@link ContinuousReader} to scan the data from this data source. + * + * @param schema the user provided schema, or empty() if none was provided + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + ContinuousReader createContinuousReader( + Optional schema, + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java deleted file mode 100644 index 824c290518ac..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for continuous stream processing. - * - * This interface is used to create {@link ContinuousReadSupport} instances when end users run - * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. 
- */ -@InterfaceStability.Evolving -public interface ContinuousReadSupportProvider extends DataSourceV2 { - - /** - * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data - * source with a user specified schema, which is called by Spark at the beginning of each - * continuous streaming query. - * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. - * - * @param schema the user provided schema. - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - default ContinuousReadSupport createContinuousReadSupport( - StructType schema, - String checkpointLocation, - DataSourceOptions options) { - return DataSourceV2Utils.failForUserSpecifiedSchema(this); - } - - /** - * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data - * source, which is called by Spark at the beginning of each continuous streaming query. - * - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - ContinuousReadSupport createContinuousReadSupport( - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6e31e84bf6c7..6234071320dc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -22,13 +22,9 @@ /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * - * Note that this is an empty interface. Data source implementations must mix in interfaces such as - * {@link BatchReadSupportProvider} or {@link BatchWriteSupportProvider}, which can provide - * batch or streaming read/write support instances. Otherwise it's just a dummy data source which - * is un-readable/writable. - * - * If Spark fails to execute any methods in the implementations of this interface (by throwing an - * exception), the read action will fail and no Spark job will be submitted. + * Note that this is an empty interface. Data source implementations should mix-in at least one of + * the plug-in interfaces like {@link ReadSupport} and {@link WriteSupport}. Otherwise it's just + * a dummy data source which is un-readable/writable. */ @InterfaceStability.Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java new file mode 100644 index 000000000000..7f4a2c9593c7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide streaming micro-batch data reading ability. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupport extends DataSourceV2 { + /** + * Creates a {@link MicroBatchReader} to read batches of data from this data source in a + * streaming query. + * + * The execution engine will create a micro-batch reader at the start of a streaming query, + * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and + * then call stop() when the execution is complete. Note that a single query may have multiple + * executions due to restart or failure recovery. + * + * @param schema the user provided schema, or empty() if none was provided + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + MicroBatchReader createMicroBatchReader( + Optional schema, + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java deleted file mode 100644 index 61c08e7fa89d..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for micro-batch stream processing. - * - * This interface is used to create {@link MicroBatchReadSupport} instances when end users run - * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupportProvider extends DataSourceV2 { - - /** - * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data - * source with a user specified schema, which is called by Spark at the beginning of each - * micro-batch streaming query. - * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. - * - * @param schema the user provided schema. - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - default MicroBatchReadSupport createMicroBatchReadSupport( - StructType schema, - String checkpointLocation, - DataSourceOptions options) { - return DataSourceV2Utils.failForUserSpecifiedSchema(this); - } - - /** - * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data - * source, which is called by Spark at the beginning of each micro-batch streaming query. - * - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - MicroBatchReadSupport createMicroBatchReadSupport( - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java similarity index 59% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index f403dc619e86..80ac08ee5ff5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -18,44 +18,48 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; -import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. 
Data sources can implement this interface to - * provide data reading ability for batch processing. - * - * This interface is used to create {@link BatchReadSupport} instances when end users run - * {@code SparkSession.read.format(...).option(...).load()}. + * provide data reading ability and scan the data from the data source. */ @InterfaceStability.Evolving -public interface BatchReadSupportProvider extends DataSourceV2 { +public interface ReadSupport extends DataSourceV2 { /** - * Creates a {@link BatchReadSupport} instance to load the data from this data source with a user - * specified schema, which is called by Spark at the beginning of each batch query. - * - * Spark will call this method at the beginning of each batch query to create a - * {@link BatchReadSupport} instance. + * Creates a {@link DataSourceReader} to scan the data from this data source. * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. * * @param schema the user specified schema. * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. */ - default BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { - return DataSourceV2Utils.failForUserSpecifiedSchema(this); + default DataSourceReader createReader(StructType schema, DataSourceOptions options) { + String name; + if (this instanceof DataSourceRegister) { + name = ((DataSourceRegister) this).shortName(); + } else { + name = this.getClass().getName(); + } + throw new UnsupportedOperationException(name + " does not support user specified schema"); } /** - * Creates a {@link BatchReadSupport} instance to scan the data from this data source, which is - * called by Spark at the beginning of each batch query. + * Creates a {@link DataSourceReader} to scan the data from this data source. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. * * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - BatchReadSupport createBatchReadSupport(DataSourceOptions options); + DataSourceReader createReader(DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index bbe430e29926..926c6fd8fd22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -28,10 +28,9 @@ public interface SessionConfigSupport extends DataSourceV2 { /** - * Key prefix of the session configs to propagate, which is usually the data source name. Spark - * will extract all session configs that starts with `spark.datasource.$keyPrefix`, turn - * `spark.datasource.$keyPrefix.xxx -> yyy` into `xxx -> yyy`, and propagate them to all - * data source operations in this session. + * Key prefix of the session configs to propagate. 
Spark will extract all session configs that + * starts with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` + * into `xxx -> yyy`, and propagate them to all data source operations in this session. */ String keyPrefix(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java new file mode 100644 index 000000000000..a77b01497269 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; +import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability for structured streaming. + */ +@InterfaceStability.Evolving +public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { + + /** + * Creates an optional {@link StreamWriter} to save the data to this data source. Data + * sources can return None if there is no writing needed to be done. + * + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link DataSourceWriter} can use this id to distinguish itself from others. + * @param schema the schema of the data to be written. + * @param mode the output mode which determines what successive epoch output means to this + * sink, please refer to {@link OutputMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + StreamWriter createStreamWriter( + String queryId, + StructType schema, + OutputMode mode, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java deleted file mode 100644 index f9ca85d8089b..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - * - * This interface is used to create {@link StreamingWriteSupport} instances when end users run - * {@code Dataset.writeStream.format(...).option(...).start()}. - */ -@InterfaceStability.Evolving -public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { - - /** - * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is - * called by Spark at the beginning of each streaming query. - * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link StreamingWriteSupport} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamingWriteSupport createStreamingWriteSupport( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java similarity index 58% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index bd10c3353bf1..048787a7a0a0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -21,39 +21,33 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for batch processing. 
- * - * This interface is used to create {@link BatchWriteSupport} instances when end users run - * {@code Dataset.write.format(...).option(...).save()}. + * provide data writing ability and save the data to the data source. */ @InterfaceStability.Evolving -public interface BatchWriteSupportProvider extends DataSourceV2 { +public interface WriteSupport extends DataSourceV2 { /** - * Creates an optional {@link BatchWriteSupport} instance to save the data to this data source, - * which is called by Spark at the beginning of each batch query. + * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data + * sources can return None if there is no writing needed to be done according to the save mode. * - * Data sources can return None if there is no writing needed to be done according to the save - * mode. + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link BatchWriteSupport} can use this id to distinguish itself from others. + * @param writeUUID A unique string for the writing job. It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceWriter} can + * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. - * @return a write support to write data to this data source. + * @return a writer to append data to this data source */ - Optional createBatchWriteSupport( - String queryId, - StructType schema, - SaveMode mode, - DataSourceOptions options); + Optional createWriter( + String writeUUID, StructType schema, SaveMode mode, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java deleted file mode 100644 index 452ee86675b4..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import org.apache.spark.annotation.InterfaceStability; - -/** - * An interface that defines how to load the data from data source for batch processing. 
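For reference, the user-facing trigger for `WriteSupport#createWriter` stays the ordinary batch write path. A small, hypothetical driver program might look like the following; the source class `com.example.NoopBatchSource` is an invented placeholder for any class implementing `WriteSupport`.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;

public class V2WriteDemo {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("v2-write-demo")
        .master("local[*]")
        .getOrCreate();

    Dataset<Row> df = spark.range(0, 10).toDF();

    // Spark resolves the class, sees that it implements WriteSupport, and calls
    // createWriter(writeUUID, schema, SaveMode.Append, options) before submitting the job.
    df.write()
        .format("com.example.NoopBatchSource")  // hypothetical class name
        .mode(SaveMode.Append)
        .option("path", "/tmp/noop")            // surfaced to the source via DataSourceOptions
        .save();

    spark.stop();
  }
}
```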
- * - * The execution engine will get an instance of this interface from a data source provider - * (e.g. {@link org.apache.spark.sql.sources.v2.BatchReadSupportProvider}) at the start of a batch - * query, then call {@link #newScanConfigBuilder()} and create an instance of {@link ScanConfig}. - * The {@link ScanConfigBuilder} can apply operator pushdown and keep the pushdown result in - * {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and reader - * factory to scan data from the data source with a Spark job. - */ -@InterfaceStability.Evolving -public interface BatchReadSupport extends ReadSupport { - - /** - * Returns a builder of {@link ScanConfig}. Spark will call this method and create a - * {@link ScanConfig} for each data scanning job. - * - * The builder can take some query specific information to do operators pushdown, and keep these - * information in the created {@link ScanConfig}. - * - * This is the first step of the data scan. All other methods in {@link BatchReadSupport} needs - * to take {@link ScanConfig} as an input. - */ - ScanConfigBuilder newScanConfigBuilder(); - - /** - * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. - */ - PartitionReaderFactory createReaderFactory(ScanConfig config); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java similarity index 61% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index 4c0eedfddfe2..dcb87715d0b6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -18,13 +18,18 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * An interface for building the {@link ScanConfig}. Implementations can mixin those - * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in - * the returned {@link ScanConfig}. + * A mix-in interface for {@link InputPartition}. Continuous input partitions can + * implement this interface to provide creating {@link InputPartitionReader} with particular offset. */ @InterfaceStability.Evolving -public interface ScanConfigBuilder { - ScanConfig build(); +public interface ContinuousInputPartition extends InputPartition { + /** + * Create an input partition reader with particular offset as its startOffset. + * + * @param offset offset want to set as the input partition reader's startOffset. + */ + InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java new file mode 100644 index 000000000000..da98fab1284e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A data source reader that is returned by + * {@link ReadSupport#createReader(DataSourceOptions)} or + * {@link ReadSupport#createReader(StructType, DataSourceOptions)}. + * It can mix in various query optimization interfaces to speed up the data scan. The actual scan + * logic is delegated to {@link InputPartition}s, which are returned by + * {@link #planInputPartitions()}. + * + * There are mainly 3 kinds of query optimizations: + * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column + * pruning), etc. Names of these interfaces start with `SupportsPushDown`. + * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. + * Names of these interfaces start with `SupportsReporting`. + * 3. Columnar scan if implements {@link SupportsScanColumnarBatch}. + * + * If an exception was throw when applying any of these query optimizations, the action will fail + * and no Spark job will be submitted. + * + * Spark first applies all operator push-down optimizations that this data source supports. Then + * Spark collects information this data source reported for further optimizations. Finally Spark + * issues the scan request and does the actual data reading. + */ +@InterfaceStability.Evolving +public interface DataSourceReader { + + /** + * Returns the actual schema of this data source reader, which may be different from the physical + * schema of the underlying storage, as column pruning or other optimizations may happen. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + */ + StructType readSchema(); + + /** + * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for + * creating a data reader to output data of one RDD partition. The number of input partitions + * returned here is the same as the number of RDD partitions this scan outputs. + * + * Note that, this may not be a full scan if the data source reader mixes in other optimization + * interfaces like column pruning, filter push-down, etc. These optimizations are applied before + * Spark issues the scan request. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. 
+ */ + List> planInputPartitions(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 95c30de907e4..f2038d0de3ff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,18 +22,18 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A serializable representation of an input partition returned by - * {@link ReadSupport#planInputPartitions(ScanConfig)}. + * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is + * responsible for creating the actual data reader of one RDD partition. + * The relationship between {@link InputPartition} and {@link InputPartitionReader} + * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that {@link InputPartition} will be serialized and sent to executors, then - * {@link PartitionReader} will be created by - * {@link PartitionReaderFactory#createReader(InputPartition)} or - * {@link PartitionReaderFactory#createColumnarReader(InputPartition)} on executors to do - * the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader} - * doesn't need to be. + * Note that {@link InputPartition}s will be serialized and sent to executors, then + * {@link InputPartitionReader}s will be created on executors to do the actual reading. So + * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to + * be. */ @InterfaceStability.Evolving -public interface InputPartition extends Serializable { +public interface InputPartition extends Serializable { /** * The preferred locations where the input partition reader returned by this partition can run @@ -51,4 +51,12 @@ public interface InputPartition extends Serializable { default String[] preferredLocations() { return new String[0]; } + + /** + * Returns an input partition reader to do the actual reading work. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + */ + InputPartitionReader createPartitionReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 04ff8d0a19fc..f3ff7f5cc0f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,27 +23,31 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or - * {@link PartitionReaderFactory#createColumnarReader(InputPartition)}. It's responsible for - * outputting data for a RDD partition. + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is + * responsible for outputting data for a RDD partition. 
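To ground the read-side contract, here is a minimal, hypothetical batch source (call it `SimpleRangeSource`; the name and the `(i, -i)` values are invented) that follows the `ReadSupport` -> `DataSourceReader` -> `InputPartition` -> `InputPartitionReader` chain restored by this patch. It is a sketch in the spirit of the existing Java DSv2 test sources, not code from any file touched here.

```java
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.types.StructType;

// Hypothetical source producing two partitions of (i, -i) rows.
public class SimpleRangeSource implements DataSourceV2, ReadSupport {

  @Override
  public DataSourceReader createReader(DataSourceOptions options) {
    return new Reader();
  }

  static class Reader implements DataSourceReader {
    @Override
    public StructType readSchema() {
      return new StructType().add("i", "int").add("j", "int");
    }

    @Override
    public List<InputPartition<InternalRow>> planInputPartitions() {
      // Each InputPartition becomes one RDD partition / one Spark task.
      return Arrays.asList(new RangePartition(0, 5), new RangePartition(5, 10));
    }
  }

  // Serializable partition; the reader it creates runs on an executor.
  static class RangePartition implements InputPartition<InternalRow> {
    private final int start;
    private final int end;

    RangePartition(int start, int end) {
      this.start = start;
      this.end = end;
    }

    @Override
    public InputPartitionReader<InternalRow> createPartitionReader() {
      return new RangeReader(start, end);
    }
  }

  static class RangeReader implements InputPartitionReader<InternalRow> {
    private int current;
    private final int end;

    RangeReader(int start, int end) {
      this.current = start - 1;
      this.end = end;
    }

    @Override
    public boolean next() throws IOException {
      current += 1;
      return current < end;
    }

    @Override
    public InternalRow get() {
      return new GenericInternalRow(new Object[] {current, -current});
    }

    @Override
    public void close() throws IOException { }
  }
}
```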
* * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow} - * for normal data sources, or {@link org.apache.spark.sql.vectorized.ColumnarBatch} for columnar - * data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)} - * returns true). + * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data + * source readers that mix in {@link SupportsScanColumnarBatch}. */ @InterfaceStability.Evolving -public interface PartitionReader extends Closeable { +public interface InputPartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + * * @throws IOException if failure happens during disk/network IO like reading files. */ boolean next() throws IOException; /** * Return the current record. This method should return same value until `next` is called. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. */ T get(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java deleted file mode 100644 index f35de9310eee..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.io.Serializable; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -/** - * A factory used to create {@link PartitionReader} instances. - * - * If Spark fails to execute any methods in the implementations of this interface or in the returned - * {@link PartitionReader} (by throwing an exception), corresponding Spark task would fail and - * get retried until hitting the maximum retry times. - */ -@InterfaceStability.Evolving -public interface PartitionReaderFactory extends Serializable { - - /** - * Returns a row-based partition reader to read data from the given {@link InputPartition}. - * - * Implementations probably need to cast the input partition to the concrete - * {@link InputPartition} class defined for the data source. - */ - PartitionReader createReader(InputPartition partition); - - /** - * Returns a columnar partition reader to read data from the given {@link InputPartition}. 
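On the consuming side nothing special is required: assuming the hypothetical `SimpleRangeSource` sketched above is on the classpath, a plain batch read drives `createReader()`, `readSchema()`, `planInputPartitions()` and then one `InputPartitionReader` per task.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SimpleRangeSourceDemo {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("v2-read-demo")
        .master("local[*]")
        .getOrCreate();

    // Fully-qualified class name of the sketch above (kept package-less for brevity).
    Dataset<Row> df = spark.read().format("SimpleRangeSource").load();
    df.show();

    spark.stop();
  }
}
```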
- * - * Implementations probably need to cast the input partition to the concrete - * {@link InputPartition} class defined for the data source. - */ - default PartitionReader createColumnarReader(InputPartition partition) { - throw new UnsupportedOperationException("Cannot create columnar reader."); - } - - /** - * Returns true if the given {@link InputPartition} should be read by Spark in a columnar way. - * This means, implementations must also implement {@link #createColumnarReader(InputPartition)} - * for the input partitions that this method returns true. - * - * As of Spark 2.4, Spark can only read all input partition in a columnar way, or none of them. - * Data source can't mix columnar and row-based partitions. This may be relaxed in future - * versions. - */ - default boolean supportColumnarReads(InputPartition partition) { - return false; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java deleted file mode 100644 index a58ddb288f1e..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.types.StructType; - -/** - * The base interface for all the batch and streaming read supports. Data sources should implement - * concrete read support interfaces like {@link BatchReadSupport}. - * - * If Spark fails to execute any methods in the implementations of this interface (by throwing an - * exception), the read action will fail and no Spark job will be submitted. - */ -@InterfaceStability.Evolving -public interface ReadSupport { - - /** - * Returns the full schema of this data source, which is usually the physical schema of the - * underlying storage. This full schema should not be affected by column pruning or other - * optimizations. - */ - StructType fullSchema(); - - /** - * Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition} - * represents a data split that can be processed by one Spark task. The number of input - * partitions returned here is the same as the number of RDD partitions this scan outputs. - * - * Note that, this may not be a full scan if the data source supports optimization like filter - * push-down. Implementations should check the input {@link ScanConfig} and adjust the resulting - * {@link InputPartition input partitions}. 
- */ - InputPartition[] planInputPartitions(ScanConfig config); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java deleted file mode 100644 index 7462ce282058..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.types.StructType; - -/** - * An interface that carries query specific information for the data scanning job, like operator - * pushdown information and streaming query offsets. This is defined as an empty interface, and data - * sources should define their own {@link ScanConfig} classes. - * - * For APIs that take a {@link ScanConfig} as input, like - * {@link ReadSupport#planInputPartitions(ScanConfig)}, - * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and - * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to - * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. - */ -@InterfaceStability.Evolving -public interface ScanConfig { - - /** - * Returns the actual schema of this data source reader, which may be different from the physical - * schema of the underlying storage, as column pruning or other optimizations may happen. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. - */ - StructType readSchema(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index 44799c7d4913..031c7a73c367 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -23,7 +23,7 @@ /** * An interface to represent statistics for a data source, which is returned by - * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. + * {@link SupportsReportStatistics#estimateStatistics()}. 
*/ @InterfaceStability.Evolving public interface Statistics { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 5e7985f645a0..7e0020f38a73 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,11 +21,11 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to - * push down filters to the data source and reduce the size of the data to be read. + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to push down filters to the data source and reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters extends ScanConfigBuilder { +public interface SupportsPushDownFilters extends DataSourceReader { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index edb164937d6e..427b4d00a112 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { +public interface SupportsPushDownRequiredColumns extends DataSourceReader { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,8 +35,8 @@ public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. * - * Note that, {@link ScanConfig#readSchema()} implementation should take care of the column - * pruning applied here. + * Note that, data source readers should update {@link DataSourceReader#readSchema()} after + * applying column pruning. */ void pruneColumns(StructType requiredSchema); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index db62cd451536..6b60da7c4dc1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -21,17 +21,17 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** - * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to - * report data partitioning and try to avoid shuffle at Spark side. 
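As a hypothetical illustration of the two push-down mix-ins after this change (both now extend `DataSourceReader` directly), a reader can record the filters it accepts and the pruned schema, and reflect the latter in `readSchema()`. The filter handling below (accepting only `GreaterThan`) and all names are invented for the example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns;
import org.apache.spark.sql.types.StructType;

// Hypothetical reader: pretends it can evaluate GreaterThan filters natively.
public class PushdownAwareReader implements DataSourceReader,
    SupportsPushDownFilters, SupportsPushDownRequiredColumns {

  private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
  private final List<Filter> pushed = new ArrayList<>();

  @Override
  public Filter[] pushFilters(Filter[] filters) {
    List<Filter> notHandled = new ArrayList<>();
    for (Filter f : filters) {
      if (f instanceof GreaterThan) {
        pushed.add(f);       // evaluated inside the source
      } else {
        notHandled.add(f);   // Spark must still evaluate these after the scan
      }
    }
    return notHandled.toArray(new Filter[0]);
  }

  @Override
  public Filter[] pushedFilters() {
    return pushed.toArray(new Filter[0]);
  }

  @Override
  public void pruneColumns(StructType requiredSchema) {
    // Remember the pruned schema; readSchema() must report it afterwards.
    this.requiredSchema = requiredSchema;
  }

  @Override
  public StructType readSchema() {
    return requiredSchema;
  }

  @Override
  public List<InputPartition<InternalRow>> planInputPartitions() {
    return Collections.emptyList();  // scan omitted; see the batch-source sketch earlier
  }
}
```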
+ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, - * Spark may avoid adding a shuffle even if the reader does not implement this interface. + * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid + * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning extends ReadSupport { +public interface SupportsReportPartitioning extends DataSourceReader { /** * Returns the output data partitioning that this reader guarantees. */ - Partitioning outputPartitioning(ScanConfig config); + Partitioning outputPartitioning(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 1831488ba096..44d0ce3c6e74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,18 +20,18 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to - * report statistics to Spark. + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to report statistics to Spark. * * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the - * data source. Implementations that return more accurate statistics based on pushed operators will - * not improve query performance until the planner can push operators before getting stats. + * DataSourceReader. Implementations that return more accurate statistics based on pushed operators + * will not improve query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics extends ReadSupport { +public interface SupportsReportStatistics extends DataSourceReader { /** - * Returns the estimated statistics of this data source scan. + * Returns the estimated statistics of this data source. */ - Statistics estimateStatistics(ScanConfig config); + Statistics estimateStatistics(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java new file mode 100644 index 000000000000..f4da686740d1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to output {@link ColumnarBatch} and make the scan faster. + */ +@InterfaceStability.Evolving +public interface SupportsScanColumnarBatch extends DataSourceReader { + @Override + default List> planInputPartitions() { + throw new IllegalStateException( + "planInputPartitions not supported by default within SupportsScanColumnarBatch."); + } + + /** + * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data + * in batches. + */ + List> planBatchInputPartitions(); + + /** + * Returns true if the concrete data source reader can read data in batch according to the scan + * properties like required columns, pushes filters, etc. It's possible that the implementation + * can only support some certain columns with certain types. Users can overwrite this method and + * {@link #planInputPartitions()} to fallback to normal read path under some conditions. + */ + default boolean enableBatchRead() { + return true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 6764d4b7665c..38ca5fc6387b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.PartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link PartitionReader}. + * {@link InputPartitionReader}. 
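For the columnar path, a sketch of a `SupportsScanColumnarBatch` reader might look as follows. It leans on the internal `OnHeapColumnVector` helper purely for illustration, and the single-partition, single-batch shape is invented; `planInputPartitions()` and `enableBatchRead()` are left at their defaults.

```java
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Hypothetical columnar reader that serves one batch of 3 integer rows.
public class TinyColumnarReader implements SupportsScanColumnarBatch {

  private static final StructType SCHEMA = new StructType().add("i", "int");

  @Override
  public StructType readSchema() {
    return SCHEMA;
  }

  @Override
  public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
    return Arrays.asList(new BatchPartition());
  }

  static class BatchPartition implements InputPartition<ColumnarBatch> {
    @Override
    public InputPartitionReader<ColumnarBatch> createPartitionReader() {
      return new BatchReader();
    }
  }

  static class BatchReader implements InputPartitionReader<ColumnarBatch> {
    private boolean consumed = false;
    private ColumnarBatch batch;

    @Override
    public boolean next() throws IOException {
      if (consumed) {
        return false;
      }
      OnHeapColumnVector[] vectors = OnHeapColumnVector.allocateColumns(3, SCHEMA);
      for (int i = 0; i < 3; i++) {
        vectors[0].putInt(i, i);
      }
      batch = new ColumnarBatch(vectors);
      batch.setNumRows(3);
      consumed = true;
      return true;
    }

    @Override
    public ColumnarBatch get() {
      return batch;
    }

    @Override
    public void close() throws IOException {
      if (batch != null) {
        batch.close();
      }
    }
  }
}
```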
*/ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index 364a3f553923..5e32ba6952e1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,14 +18,14 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.PartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions (one {@link PartitionReader} outputs data for one + * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one * partition). * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link PartitionReader}). + * partition(the output records of a single {@link InputPartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index fb0b6f1df43b..f460f6bfe3bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -19,13 +19,12 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by - * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work - * like a snapshot. Once created, it should be deterministic and always report the same number of + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a + * snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. 
*/ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java similarity index 60% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java index 9101c8a44d34..7b0ba0bbdda9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java @@ -18,20 +18,19 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.PartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** - * A variation on {@link PartitionReader} for use with continuous streaming processing. + * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode. */ @InterfaceStability.Evolving -public interface ContinuousPartitionReader extends PartitionReader { - - /** - * Get the offset of the current record, or the start offset if no records have been read. - * - * The execution engine will call this method along with get() to keep track of the current - * offset. When an epoch ends, the offset of the previous record in each partition will be saved - * as a restart checkpoint. - */ - PartitionOffset getOffset(); +public interface ContinuousInputPartitionReader extends InputPartitionReader { + /** + * Get the offset of the current record, or the start offset if no records have been read. + * + * The execution engine will call this method along with get() to keep track of the current + * offset. When an epoch ends, the offset of the previous record in each partition will be saved + * as a restart checkpoint. + */ + PartitionOffset getOffset(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java deleted file mode 100644 index 2d9f1ca1686a..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -/** - * A variation on {@link PartitionReaderFactory} that returns {@link ContinuousPartitionReader} - * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for - * continuous streaming processing. - */ -@InterfaceStability.Evolving -public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { - @Override - ContinuousPartitionReader createReader(InputPartition partition); - - @Override - default ContinuousPartitionReader createColumnarReader(InputPartition partition) { - throw new UnsupportedOperationException("Cannot create columnar reader."); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java deleted file mode 100644 index 9a3ad2eb8a80..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.ScanConfig; -import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; - -/** - * An interface that defines how to load the data from data source for continuous streaming - * processing. - * - * The execution engine will get an instance of this interface from a data source provider - * (e.g. {@link org.apache.spark.sql.sources.v2.ContinuousReadSupportProvider}) at the start of a - * streaming query, then call {@link #newScanConfigBuilder(Offset)} and create an instance of - * {@link ScanConfig} for the duration of the streaming query or until - * {@link #needsReconfiguration(ScanConfig)} is true. The {@link ScanConfig} will be used to create - * input partitions and reader factory to scan data with a Spark job for its duration. At the end - * {@link #stop()} will be called when the streaming execution is completed. Note that a single - * query may have multiple executions due to restart or failure recovery. 
- */ -@InterfaceStability.Evolving -public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { - - /** - * Returns a builder of {@link ScanConfig}. Spark will call this method and create a - * {@link ScanConfig} for each data scanning job. - * - * The builder can take some query specific information to do operators pushdown, store streaming - * offsets, etc., and keep these information in the created {@link ScanConfig}. - * - * This is the first step of the data scan. All other methods in {@link ContinuousReadSupport} - * needs to take {@link ScanConfig} as an input. - */ - ScanConfigBuilder newScanConfigBuilder(Offset start); - - /** - * Returns a factory, which produces one {@link ContinuousPartitionReader} for one - * {@link InputPartition}. - */ - ContinuousPartitionReaderFactory createContinuousReaderFactory(ScanConfig config); - - /** - * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances - * for each partition to a single global offset. - */ - Offset mergeOffsets(PartitionOffset[] offsets); - - /** - * The execution engine will call this method in every epoch to determine if new input - * partitions need to be generated, which may be required if for example the underlying - * source system has had partitions added or removed. - * - * If true, the query will be shut down and restarted with a new {@link ContinuousReadSupport} - * instance. - */ - default boolean needsReconfiguration(ScanConfig config) { - return false; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java new file mode 100644 index 000000000000..6e960bedf802 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; + +import java.util.Optional; + +/** + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to allow reading in a continuous processing mode stream. + * + * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. 
+ */ +@InterfaceStability.Evolving +public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { + /** + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances + * for each partition to a single global offset. + */ + Offset mergeOffsets(PartitionOffset[] offsets); + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + Offset deserializeOffset(String json); + + /** + * Set the desired start offset for partitions created from this reader. The scan will + * start from the first record after the provided offset, or from an implementation-defined + * inferred starting point if no offset is provided. + */ + void setStartOffset(Optional start); + + /** + * Return the specified or inferred start offset for this reader. + * + * @throws IllegalStateException if setStartOffset has not been called + */ + Offset getStartOffset(); + + /** + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. + * + * If true, the query will be shut down and restarted with a new reader. + */ + default boolean needsReconfiguration() { + return false; + } + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java deleted file mode 100644 index edb0db11bff2..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.*; - -/** - * An interface that defines how to scan the data from data source for micro-batch streaming - * processing. - * - * The execution engine will get an instance of this interface from a data source provider - * (e.g. {@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a - * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance - * of {@link ScanConfig} for each micro-batch. 
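To ground the `ContinuousReader` contract added above, here is a hypothetical, heavily simplified counter source. The offset classes and the single endless partition are invented for illustration; it leaves `needsReconfiguration()` at its default and glosses over checkpoint/restart details a real source must handle.

```java
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader;
import org.apache.spark.sql.sources.v2.reader.streaming.Offset;
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;
import org.apache.spark.sql.types.StructType;

// Hypothetical continuous source: one partition emitting an ever-increasing counter.
public class CounterContinuousReader implements ContinuousReader {

  /** Global offset: just the highest counter value seen so far. */
  static class CounterOffset extends Offset {
    final long value;
    CounterOffset(long value) { this.value = value; }
    @Override public String json() { return Long.toString(value); }
  }

  /** Per-partition offset reported by the reader on the executor. */
  static class CounterPartitionOffset implements PartitionOffset {
    final long value;
    CounterPartitionOffset(long value) { this.value = value; }
  }

  private Offset start;

  @Override
  public void setStartOffset(Optional<Offset> offset) {
    // Fall back to an inferred starting point when no checkpoint is provided.
    start = offset.orElse(new CounterOffset(0L));
  }

  @Override public Offset getStartOffset() { return start; }

  @Override
  public Offset deserializeOffset(String json) { return new CounterOffset(Long.parseLong(json)); }

  @Override
  public Offset mergeOffsets(PartitionOffset[] offsets) {
    // Only one partition here; real sources fold all partition offsets together.
    return new CounterOffset(((CounterPartitionOffset) offsets[0]).value);
  }

  @Override public void commit(Offset end) { /* nothing to clean up */ }

  @Override public void stop() { }

  @Override
  public StructType readSchema() { return new StructType().add("value", "long"); }

  @Override
  public List<InputPartition<InternalRow>> planInputPartitions() {
    long startValue = ((CounterOffset) start).value;
    return Arrays.asList(new CounterPartition(startValue));
  }

  static class CounterPartition implements InputPartition<InternalRow> {
    private final long startValue;
    CounterPartition(long startValue) { this.startValue = startValue; }

    @Override
    public InputPartitionReader<InternalRow> createPartitionReader() {
      return new CounterPartitionReader(startValue);
    }
  }

  // Must be a ContinuousInputPartitionReader so the engine can track offsets per record.
  static class CounterPartitionReader implements ContinuousInputPartitionReader<InternalRow> {
    private long current;
    CounterPartitionReader(long startValue) { this.current = startValue; }

    @Override public boolean next() throws IOException { current += 1; return true; }

    @Override public InternalRow get() { return new GenericInternalRow(new Object[] {current}); }

    @Override public PartitionOffset getOffset() { return new CounterPartitionOffset(current); }

    @Override public void close() throws IOException { }
  }
}
```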
The {@link ScanConfig} will be used to create input - * partitions and reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()} - * will be called when the streaming execution is completed. Note that a single query may have - * multiple executions due to restart or failure recovery. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { - - /** - * Returns a builder of {@link ScanConfig}. Spark will call this method and create a - * {@link ScanConfig} for each data scanning job. - * - * The builder can take some query specific information to do operators pushdown, store streaming - * offsets, etc., and keep these information in the created {@link ScanConfig}. - * - * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport} - * needs to take {@link ScanConfig} as an input. - */ - ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end); - - /** - * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. - */ - PartitionReaderFactory createReaderFactory(ScanConfig config); - - /** - * Returns the most recent offset available. - */ - Offset latestOffset(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java new file mode 100644 index 000000000000..0159c731762d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; + +import java.util.Optional; + +/** + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to indicate they allow micro-batch streaming reads. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. + */ +@InterfaceStability.Evolving +public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { + /** + * Set the desired offset range for input partitions created from this reader. Partition readers + * will generate only data within (`start`, `end`]; that is, from the first record after `start` + * to the record with offset `end`. + * + * @param start The initial offset to scan from. 
If not specified, scan from an + * implementation-specified start point, such as the earliest available record. + * @param end The last offset to include in the scan. If not specified, scan up to an + * implementation-defined endpoint, such as the last available offset + * or the start offset plus a target batch size. + */ + void setOffsetRange(Optional start, Optional end); + + /** + * Returns the specified (if explicitly set through setOffsetRange) or inferred start offset + * for this reader. + * + * @throws IllegalStateException if setOffsetRange has not been called + */ + Offset getStartOffset(); + + /** + * Return the specified (if explicitly set through setOffsetRange) or inferred end offset + * for this reader. + * + * @throws IllegalStateException if setOffsetRange has not been called + */ + Offset getEndOffset(); + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index 6cf27734867c..e41c0351edc8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -20,8 +20,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An abstract representation of progress through a {@link MicroBatchReadSupport} or - * {@link ContinuousReadSupport}. + * An abstract representation of progress through a {@link MicroBatchReader} or + * {@link ContinuousReader}. * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java deleted file mode 100644 index 84872d1ebc26..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
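A matching hypothetical sketch for the `MicroBatchReader` mix-in added above: the same counter idea, where each batch covers the half-open range implied by (`start`, `end`]. All names are invented, and the end-offset fallback (start plus five) is deliberately trivial.

```java
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader;
import org.apache.spark.sql.sources.v2.reader.streaming.Offset;
import org.apache.spark.sql.types.StructType;

// Hypothetical micro-batch source: each batch emits a bounded slice of a counter.
public class CounterMicroBatchReader implements MicroBatchReader {

  static class CounterOffset extends Offset {
    final long value;
    CounterOffset(long value) { this.value = value; }
    @Override public String json() { return Long.toString(value); }
  }

  private CounterOffset start;
  private CounterOffset end;

  @Override
  public void setOffsetRange(Optional<Offset> start, Optional<Offset> end) {
    this.start = (CounterOffset) start.orElse(new CounterOffset(0L));
    // If the engine does not cap the batch, pick an implementation-defined end point.
    this.end = (CounterOffset) end.orElse(new CounterOffset(this.start.value + 5));
  }

  @Override public Offset getStartOffset() { return start; }

  @Override public Offset getEndOffset() { return end; }

  @Override
  public Offset deserializeOffset(String json) { return new CounterOffset(Long.parseLong(json)); }

  @Override public void commit(Offset end) { /* data at or before `end` may now be dropped */ }

  @Override public void stop() { }

  @Override
  public StructType readSchema() { return new StructType().add("value", "long"); }

  @Override
  public List<InputPartition<InternalRow>> planInputPartitions() {
    // One partition covering (start, end]; real sources usually split this further.
    return Arrays.asList(new SlicePartition(start.value, end.value));
  }

  static class SlicePartition implements InputPartition<InternalRow> {
    private final long from;
    private final long to;
    SlicePartition(long from, long to) { this.from = from; this.to = to; }

    @Override
    public InputPartitionReader<InternalRow> createPartitionReader() {
      return new SliceReader(from, to);
    }
  }

  static class SliceReader implements InputPartitionReader<InternalRow> {
    private long current;
    private final long to;
    SliceReader(long from, long to) { this.current = from; this.to = to; }

    @Override public boolean next() throws IOException { current += 1; return current <= to; }

    @Override public InternalRow get() { return new GenericInternalRow(new Object[] {current}); }

    @Override public void close() throws IOException { }
  }
}
```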
- */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.sql.sources.v2.reader.ReadSupport; - -/** - * A base interface for streaming read support. This is package private and is invisible to data - * sources. Data sources should implement concrete streaming read support interfaces: - * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. - */ -interface StreamingReadSupport extends ReadSupport { - - /** - * Returns the initial offset for a streaming query to start reading from. Note that the - * streaming data source should not assume that it will start reading from its initial offset: - * if Spark is restarting an existing query, it will restart from the check-pointed offset rather - * than the initial one. - */ - Offset initialOffset(); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0ec9e05d6a02..385fc294fea8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -18,13 +18,28 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.StreamWriteSupport; +import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.types.StructType; /** - * An interface that defines how to write the data to data source for batch processing. + * A data source writer that is returned by + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ + * {@link StreamWriteSupport#createStreamWriter( + * String, StructType, OutputMode, DataSourceOptions)}. + * It can mix in various writing optimization interfaces to speed up the data saving. The actual + * writing logic is delegated to {@link DataWriter}. + * + * If an exception was throw when applying any of these writing optimizations, the action will fail + * and no Spark job will be submitted. * * The writing procedure is: - * 1. Create a writer factory by {@link #createBatchWriterFactory()}, serialize and send it to all - * the partitions of the input data(RDD). + * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the + * partitions of the input data(RDD). * 2. For each partition, create the data writer, and write the data of the partition with this * writer. If all the data are written successfully, call {@link DataWriter#commit()}. 
If * exception happens during the writing, call {@link DataWriter#abort()}. @@ -38,7 +53,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface BatchWriteSupport { +public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. @@ -46,7 +61,7 @@ public interface BatchWriteSupport { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ - DataWriterFactory createBatchWriterFactory(); + DataWriterFactory createWriterFactory(); /** * Returns whether Spark should use the commit coordinator to ensure that at most one task for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 5fb067966ee6..27dc5ea224fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -36,11 +36,11 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link BatchWriteSupport#commit(WriterCommitMessage[])} with commit messages from other data + * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createWriter(int, long)} will receive a - * different `taskId`. Spark will call {@link BatchWriteSupport#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a + * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task @@ -71,11 +71,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link BatchWriteSupport#commit(WriterCommitMessage[])}. + * {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link BatchWriteSupport#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link BatchWriteSupport} at driver side to + * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to * do the final commit via {@link WriterCommitMessage}. 
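The commit protocol described above can be sketched as a minimal DataSourceWriter. LoggingWriter is a hypothetical name; the only required pieces are the factory handed to executors and the driver-side commit/abort that act on the collected WriterCommitMessages.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical writer: the factory is created on the driver and shipped to executors;
// commit() runs on the driver only after every partition has produced a commit message.
class LoggingWriter(factory: DataWriterFactory[InternalRow]) extends DataSourceWriter {

  override def createWriterFactory(): DataWriterFactory[InternalRow] = factory

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // Make all partitions' output visible in one step, e.g. move staged files into place.
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // Best-effort cleanup of whatever the (possibly partial) messages describe.
  }
}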
* * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -93,7 +93,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link BatchWriteSupport#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 19a36dd23245..3d337b6e0bdf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -19,20 +19,18 @@ import java.io.Serializable; -import org.apache.spark.TaskContext; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; /** - * A factory of {@link DataWriter} returned by {@link BatchWriteSupport#createBatchWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer - * will be created on executors and do the actual writing. So this interface must be + * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be * serializable and {@link DataWriter} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataWriterFactory extends Serializable { +public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data @@ -40,16 +38,19 @@ public interface DataWriterFactory extends Serializable { * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a * list. * - * If this method fails (by throwing an exception), the corresponding Spark write task would fail - * and get retried until hitting the maximum retry times. + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run - * multiple tasks for the same partition (due to speculation or task failures, - * for example). + * @param taskId A unique identifier for a task that is performing the write of the partition + * data. Spark may run multiple tasks for the same partition (due to speculation + * or task failures, for example). + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. For non-streaming queries, + * this ID will always be 0. 
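A matching factory/writer pair, again with hypothetical names (BufferingWriterFactory, BufferingDataWriter, RowsWritten), shows how the createDataWriter parameters above are typically used and why a defensive copy matters: Spark reuses the same InternalRow instance between write() calls.

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical commit message: just reports how many rows a partition buffered.
case class RowsWritten(partitionId: Int, count: Long) extends WriterCommitMessage

class BufferingWriterFactory extends DataWriterFactory[InternalRow] {
  // partitionId distinguishes writers; taskId changes across retries; epochId is 0 for batch jobs.
  override def createDataWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new BufferingDataWriter(partitionId)
}

class BufferingDataWriter(partitionId: Int) extends DataWriter[InternalRow] {
  private val buffer = ArrayBuffer.empty[InternalRow]

  // Copy the row: the caller reuses the same object instance for the next record.
  override def write(record: InternalRow): Unit = buffer += record.copy()

  // The returned message goes to the driver; data stays invisible until the driver-side commit.
  override def commit(): WriterCommitMessage = RowsWritten(partitionId, buffer.size.toLong)

  override def abort(): Unit = buffer.clear()
}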
*/ - DataWriter createWriter(int partitionId, long taskId); + DataWriter createDataWriter(int partitionId, long taskId, long epochId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 123335c414e9..9e38836c0edf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,16 +19,15 @@ import java.io.Serializable; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link BatchWriteSupport#commit(WriterCommitMessage[])} or - * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. + * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. * - * This is an empty interface, data sources should define their own message class and use it when - * generating messages at executor side and handling the messages at driver side. + * This is an empty interface, data sources should define their own message class and use it in + * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} + * implementations. */ @InterfaceStability.Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 3fdfac5e1c84..a316b2a4c1d8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -18,36 +18,27 @@ package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * An interface that defines how to write the data to data source for streaming processing. + * A {@link DataSourceWriter} for use with structured streaming. * * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving -public interface StreamingWriteSupport { - - /** - * Creates a writer factory which will be serialized and sent to executors. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. - */ - StreamingDataWriterFactory createStreamingWriterFactory(); - +public interface StreamWriter extends DataSourceWriter { /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by * {@link DataWriter#commit()}. 
* * If this method fails (by throwing an exception), this writing job is considered to have been - * failed, and the execution engine will attempt to call - * {@link #abort(long, WriterCommitMessage[])}. + * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. * - * The execution engine may call `commit` multiple times for the same epoch in some circumstances. + * The execution engine may call commit() multiple times for the same epoch in some circumstances. * To support exactly-once data semantics, implementations must ensure that multiple commits for * the same epoch are idempotent. */ @@ -55,8 +46,7 @@ public interface StreamingWriteSupport { /** * Aborts this writing job because some data writers are failed and keep failing when retried, or - * the Spark job fails with some unknown reasons, or {@link #commit(long, WriterCommitMessage[])} - * fails. + * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. @@ -68,4 +58,14 @@ public interface StreamingWriteSupport { * clean up the data left by data writers. */ void abort(long epochId, WriterCommitMessage[] messages); + + default void commit(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Commit without epoch should not be called with StreamWriter"); + } + + default void abort(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Abort without epoch should not be called with StreamWriter"); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java deleted file mode 100644 index a4da24fc5ae6..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.writer.streaming; - -import java.io.Serializable; - -import org.apache.spark.TaskContext; -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.sources.v2.writer.DataWriter; - -/** - * A factory of {@link DataWriter} returned by - * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating - * and initializing the actual data writer at executor side. - * - * Note that, the writer factory will be serialized and sent to executors, then the data writer - * will be created on executors and do the actual writing. 
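Because the engine may deliver the same epoch's commit more than once, a StreamWriter has to make the epoch-aware commit idempotent, for example by remembering the last epoch it published. IdempotentStreamWriter below is a hypothetical sketch built on that idea, reusing the batch writer factory from the sketch above.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

class IdempotentStreamWriter(factory: DataWriterFactory[InternalRow]) extends StreamWriter {

  // In a real sink this watermark would be stored durably alongside the data.
  @volatile private var lastCommittedEpoch: Long = -1L

  override def createWriterFactory(): DataWriterFactory[InternalRow] = factory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    if (epochId > lastCommittedEpoch) {
      // Publish this epoch's data exactly once.
      lastCommittedEpoch = epochId
    }
    // Otherwise this is a redelivery of an already-committed epoch: do nothing.
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // Discard whatever the failed epoch staged.
  }
}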
So this interface must be - * serializable and {@link DataWriter} doesn't need to be. - */ -@InterfaceStability.Evolving -public interface StreamingDataWriterFactory extends Serializable { - - /** - * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data - * object instance when sending data to the data writer, for better performance. Data writers - * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a - * list. - * - * If this method fails (by throwing an exception), the corresponding Spark write task would fail - * and get retried until hitting the maximum retry times. - * - * @param partitionId A unique id of the RDD partition that the returned writer will process. - * Usually Spark processes many RDD partitions at the same time, - * implementations should use the partition id to distinguish writers for - * different partitions. - * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run - * multiple tasks for the same partition (due to speculation or task failures, - * for example). - * @param epochId A monotonically increasing id for streaming queries that are split in to - * discrete periods of execution. - */ - DataWriter createWriter(int partitionId, long taskId, long epochId); -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0cfcc45fb3d3..5b3b5c2451aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -194,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[BatchReadSupportProvider]) { + if (ds.isInstanceOf[ReadSupport]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) val pathsOption = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index eca2d5b97190..650c91790a75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -240,7 +240,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (classOf[DataSourceV2].isAssignableFrom(cls)) { val source = cls.newInstance().asInstanceOf[DataSourceV2] source match { - case provider: BatchWriteSupportProvider => + case ws: WriteSupport => val options = extraOptions ++ DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) @@ -251,10 +251,8 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) { } } else { - val writer = provider.createBatchWriteSupport( - UUID.randomUUID().toString, - df.logicalPlan.output.toStructType, - mode, + val writer = ws.createWriter( + UUID.randomUUID.toString, df.logicalPlan.output.toStructType, mode, new DataSourceOptions(options.asJava)) if (writer.isPresent) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index f62f7349d1da..782829887c44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,22 +17,19 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark._ +import scala.reflect.ClassTag + +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.sources.v2.reader.InputPartition -class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) extends Partition with Serializable -// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for -// columnar scan. -class DataSourceRDD( +class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val inputPartitions: Seq[InputPartition], - partitionReaderFactory: PartitionReaderFactory, - columnarReads: Boolean) - extends RDD[InternalRow](sc, Nil) { + @transient private val inputPartitions: Seq[InputPartition[T]]) + extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { @@ -40,21 +37,11 @@ class DataSourceRDD( }.toArray } - private def castPartition(split: Partition): DataSourceRDDPartition = split match { - case p: DataSourceRDDPartition => p - case _ => throw new SparkException(s"[BUG] Not a DataSourceRDDPartition: $split") - } - - override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { - val inputPartition = castPartition(split).inputPartition - val reader: PartitionReader[_] = if (columnarReads) { - partitionReaderFactory.createColumnarReader(inputPartition) - } else { - partitionReaderFactory.createReader(inputPartition) - } - + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition + .createPartitionReader() context.addTaskCompletionListener[Unit](_ => reader.close()) - val iter = new Iterator[Any] { + val iter = new Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -64,7 +51,7 @@ class DataSourceRDD( valuePrepared } - override def next(): Any = { + override def next(): T = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -72,11 +59,10 @@ class DataSourceRDD( reader.get() } } - // TODO: SPARK-25083 remove the type erasure hack in data source scan - new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) + new InterruptibleIterator(context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { - castPartition(split).inputPartition.preferredLocations() + 
split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index f7e29593a635..abc5fb979250 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -27,21 +27,21 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics} -import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter import org.apache.spark.sql.types.StructType /** * A logical plan representing a data source v2 scan. * * @param source An instance of a [[DataSourceV2]] implementation. - * @param options The options for this scan. Used to create fresh [[BatchWriteSupport]]. - * @param userSpecifiedSchema The user-specified schema for this scan. + * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. + * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh + * [[DataSourceReader]]. */ case class DataSourceV2Relation( source: DataSourceV2, - readSupport: BatchReadSupport, output: Seq[AttributeReference], options: Map[String, String], tableIdent: Option[TableIdentifier] = None, @@ -58,12 +58,13 @@ case class DataSourceV2Relation( override def simpleString: String = "RelationV2 " + metadataString - def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) + def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) - override def computeStats(): Statistics = readSupport match { + def newWriter(): DataSourceWriter = source.createWriter(options, schema) + + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => - val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build()) - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + Statistics(sizeInBytes = r.estimateStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -84,8 +85,7 @@ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], source: DataSourceV2, options: Map[String, String], - readSupport: ReadSupport, - scanConfigBuilder: ScanConfigBuilder) + reader: DataSourceReader) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { override def isStreaming: Boolean = true @@ -99,8 +99,7 @@ case class StreamingDataSourceV2Relation( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
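The computeStats() cases above look for SupportsReportStatistics on the reader and otherwise fall back to defaultSizeInBytes. A hypothetical reader advertising fixed statistics might look like the following; FixedSizeReader and the reported numbers are illustrative only.

import java.util.{Collections, OptionalLong, List => JList}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, Statistics, SupportsReportStatistics}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

class FixedSizeReader extends SupportsReportStatistics {

  override def readSchema(): StructType = StructType(Seq(StructField("value", LongType)))

  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Collections.emptyList[InputPartition[InternalRow]]()

  // Reported to the optimizer; absent values fall back to session defaults.
  override def estimateStatistics(): Statistics = new Statistics {
    override def sizeInBytes(): OptionalLong = OptionalLong.of(10L * 1024 * 1024)
    override def numRows(): OptionalLong = OptionalLong.of(1000000L)
  }
}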
override def equals(other: Any): Boolean = other match { case other: StreamingDataSourceV2Relation => - output == other.output && readSupport.getClass == other.readSupport.getClass && - options == other.options + output == other.output && reader.getClass == other.reader.getClass && options == other.options case _ => false } @@ -108,10 +107,9 @@ case class StreamingDataSourceV2Relation( Seq(output, source, options).hashCode() } - override def computeStats(): Statistics = readSupport match { + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => - val statistics = r.estimateStatistics(scanConfigBuilder.build()) - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + Statistics(sizeInBytes = r.estimateStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -119,19 +117,19 @@ case class StreamingDataSourceV2Relation( object DataSourceV2Relation { private implicit class SourceHelpers(source: DataSourceV2) { - def asReadSupportProvider: BatchReadSupportProvider = { + def asReadSupport: ReadSupport = { source match { - case provider: BatchReadSupportProvider => - provider + case support: ReadSupport => + support case _ => throw new AnalysisException(s"Data source is not readable: $name") } } - def asWriteSupportProvider: BatchWriteSupportProvider = { + def asWriteSupport: WriteSupport = { source match { - case provider: BatchWriteSupportProvider => - provider + case support: WriteSupport => + support case _ => throw new AnalysisException(s"Data source is not writable: $name") } @@ -146,26 +144,23 @@ object DataSourceV2Relation { } } - def createReadSupport( + def createReader( options: Map[String, String], - userSpecifiedSchema: Option[StructType]): BatchReadSupport = { + userSpecifiedSchema: Option[StructType]): DataSourceReader = { val v2Options = new DataSourceOptions(options.asJava) userSpecifiedSchema match { case Some(s) => - asReadSupportProvider.createBatchReadSupport(s, v2Options) + asReadSupport.createReader(s, v2Options) case _ => - asReadSupportProvider.createBatchReadSupport(v2Options) + asReadSupport.createReader(v2Options) } } - def createWriteSupport( + def createWriter( options: Map[String, String], - schema: StructType): BatchWriteSupport = { - asWriteSupportProvider.createBatchWriteSupport( - UUID.randomUUID().toString, - schema, - SaveMode.Append, - new DataSourceOptions(options.asJava)).get + schema: StructType): DataSourceWriter = { + val v2Options = new DataSourceOptions(options.asJava) + asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get } } @@ -174,16 +169,15 @@ object DataSourceV2Relation { options: Map[String, String], tableIdent: Option[TableIdentifier] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val readSupport = source.createReadSupport(options, userSpecifiedSchema) - val output = readSupport.fullSchema().toAttributes + val reader = source.createReader(options, userSpecifiedSchema) val ident = tableIdent.orElse(tableFromOptions(options)) DataSourceV2Relation( - source, readSupport, output, options, ident, userSpecifiedSchema) + source, reader.readSchema().toAttributes, options, ident, userSpecifiedSchema) } private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { options - .get(DataSourceOptions.TABLE_KEY) - .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) + .get(DataSourceOptions.TABLE_KEY) + 
.map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 04a97735d024..c8494f97f176 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -26,7 +28,8 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical plan node for scanning data from a data source. @@ -36,8 +39,7 @@ case class DataSourceV2ScanExec( @transient source: DataSourceV2, @transient options: Map[String, String], @transient pushedFilters: Seq[Expression], - @transient readSupport: ReadSupport, - @transient scanConfig: ScanConfig) + @transient reader: DataSourceReader) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { override def simpleString: String = "ScanV2 " + metadataString @@ -45,8 +47,7 @@ case class DataSourceV2ScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
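DataSourceV2ScanExec above ultimately materializes the reader's InputPartitions through DataSourceRDD (shown earlier), driving each partition's reader with next()/get() until exhaustion and closing it on task completion. A hypothetical row-based partition, with an illustrative name (RangeInputPartition), could look like this:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}

// Hypothetical partition producing the longs in [start, end); created on the driver and
// serialized to an executor, where createPartitionReader() is called.
class RangeInputPartition(start: Long, end: Long) extends InputPartition[InternalRow] {

  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new InputPartitionReader[InternalRow] {
      private var current = start - 1

      override def next(): Boolean = { current += 1; current < end }
      override def get(): InternalRow = new GenericInternalRow(Array[Any](current))
      override def close(): Unit = ()
    }
}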
override def equals(other: Any): Boolean = other match { case other: DataSourceV2ScanExec => - output == other.output && readSupport.getClass == other.readSupport.getClass && - options == other.options + output == other.output && reader.getClass == other.reader.getClass && options == other.options case _ => false } @@ -54,39 +55,36 @@ case class DataSourceV2ScanExec( Seq(output, source, options).hashCode() } - override def outputPartitioning: physical.Partitioning = readSupport match { - case _ if partitions.length == 1 => + override def outputPartitioning: physical.Partitioning = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => + SinglePartition + + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => + SinglePartition + + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => SinglePartition case s: SupportsReportPartitioning => new DataSourcePartitioning( - s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) - - private lazy val readerFactory = readSupport match { - case r: BatchReadSupport => r.createReaderFactory(scanConfig) - case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) - case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) - case _ => throw new IllegalStateException("unknown read support: " + readSupport) + private lazy val partitions: Seq[InputPartition[InternalRow]] = { + reader.planInputPartitions().asScala } - // TODO: clean this up when we have dedicated scan plan for continuous streaming. 
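The reader whose planInputPartitions() is consulted above is produced by a ReadSupport provider. A minimal, hypothetical end-to-end source (RangeSource is an illustrative name, reusing RangeInputPartition from the previous sketch) ties the pieces together:

import java.util.{ArrayList => JArrayList, List => JList}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

class RangeSource extends DataSourceV2 with ReadSupport {

  override def createReader(options: DataSourceOptions): DataSourceReader =
    new DataSourceReader {
      // Drives the relation's output attributes via readSchema().toAttributes.
      override def readSchema(): StructType = StructType(Seq(StructField("value", LongType)))

      // One entry per RDD partition of the scan.
      override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
        val partitions = new JArrayList[InputPartition[InternalRow]]()
        partitions.add(new RangeInputPartition(0, 5))
        partitions.add(new RangeInputPartition(5, 10))
        partitions
      }
    }
}

With such a class on the classpath, spark.read.format(classOf[RangeSource].getName).load() is routed through the ReadSupport check in DataFrameReader shown earlier in this patch.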
- override val supportsBatch: Boolean = { - require(partitions.forall(readerFactory.supportColumnarReads) || - !partitions.exists(readerFactory.supportColumnarReads), - "Cannot mix row-based and columnar input partitions.") - - partitions.exists(readerFactory.supportColumnarReads) + private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + assert(!reader.isInstanceOf[ContinuousReader], + "continuous stream reader does not support columnar read yet.") + r.planBatchInputPartitions().asScala } - private lazy val inputRDD: RDD[InternalRow] = readSupport match { - case _: ContinuousReadSupport => - assert(!supportsBatch, - "continuous stream reader does not support columnar read yet.") + private lazy val inputRDD: RDD[InternalRow] = reader match { + case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -95,17 +93,22 @@ case class DataSourceV2ScanExec( sparkContext, sqlContext.conf.continuousStreamingExecutorQueueSize, sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - partitions, - schema, - readerFactory.asInstanceOf[ContinuousPartitionReaderFactory]) + partitions).asInstanceOf[RDD[InternalRow]] + + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD( - sparkContext, partitions, readerFactory.asInstanceOf[PartitionReaderFactory], supportsBatch) + new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + override val supportsBatch: Boolean = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => true + case _ => false + } + override protected def needsUnsafeRowConversion: Boolean = false override protected def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 9a3109e7c199..9d97d3b58f30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Rep import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader object DataSourceV2Strategy extends Strategy { @@ -37,9 +37,9 @@ object DataSourceV2Strategy extends Strategy { * @return pushed filter and post-scan filters. 
*/ private def pushFilters( - configBuilder: ScanConfigBuilder, + reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - configBuilder match { + reader match { case r: SupportsPushDownFilters => // A map from translated data source filters to original catalyst filter expressions. val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] @@ -71,43 +71,41 @@ object DataSourceV2Strategy extends Strategy { /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * - * @return the created `ScanConfig`(since column pruning is the last step of operator pushdown), - * and new output attributes after column pruning. + * @return new output attributes after column pruning. */ // TODO: nested column pruning. private def pruneColumns( - configBuilder: ScanConfigBuilder, + reader: DataSourceReader, relation: DataSourceV2Relation, - exprs: Seq[Expression]): (ScanConfig, Seq[AttributeReference]) = { - configBuilder match { + exprs: Seq[Expression]): Seq[AttributeReference] = { + reader match { case r: SupportsPushDownRequiredColumns => val requiredColumns = AttributeSet(exprs.flatMap(_.references)) val neededOutput = relation.output.filter(requiredColumns.contains) if (neededOutput != relation.output) { r.pruneColumns(neededOutput.toStructType) - val config = r.build() val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - config -> config.readSchema().toAttributes.map { + r.readSchema().toAttributes.map { // We have to keep the attribute id during transformation. a => a.withExprId(nameToAttr(a.name).exprId) } } else { - r.build() -> relation.output + relation.output } - case _ => configBuilder.build() -> relation.output + case _ => relation.output } } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val configBuilder = relation.readSupport.newScanConfigBuilder() + val reader = relation.newReader() // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFilters) = pushFilters(configBuilder, filters) - val (config, output) = pruneColumns(configBuilder, relation, project ++ postScanFilters) + val (pushedFilters, postScanFilters) = pushFilters(reader, filters) + val output = pruneColumns(reader, relation, project ++ postScanFilters) logInfo( s""" |Pushing operators to ${relation.source.getClass} @@ -117,12 +115,7 @@ object DataSourceV2Strategy extends Strategy { """.stripMargin) val scan = DataSourceV2ScanExec( - output, - relation.source, - relation.options, - pushedFilters, - relation.readSupport, - config) + output, relation.source, relation.options, pushedFilters, reader) val filterCondition = postScanFilters.reduceLeftOption(And) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) @@ -131,26 +124,22 @@ object DataSourceV2Strategy extends Strategy { ProjectExec(project, withFilter) :: Nil case r: StreamingDataSourceV2Relation => - // TODO: support operator pushdown for streaming data sources. 
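pushFilters() and pruneColumns() above are driven by two mix-in interfaces on the reader. A hypothetical reader supporting both (PrunableReader; the IsNotNull-only policy is just for illustration) keeps the pushed filters and the narrowed schema as mutable state that readSchema() subsequently reflects:

import java.util.{Collections, List => JList}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{Filter, IsNotNull}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition,
  SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

class PrunableReader extends DataSourceReader
    with SupportsPushDownFilters with SupportsPushDownRequiredColumns {

  private var schema =
    StructType(Seq(StructField("id", LongType), StructField("value", LongType)))
  private var pushed = Array.empty[Filter]

  // Keep the filters this source can evaluate; return the rest for Spark to apply post-scan.
  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (supported, unsupported) = filters.partition(_.isInstanceOf[IsNotNull])
    pushed = supported
    unsupported
  }

  override def pushedFilters(): Array[Filter] = pushed

  // Column pruning is the last pushdown step; readSchema() must reflect it afterwards.
  override def pruneColumns(requiredSchema: StructType): Unit = { schema = requiredSchema }

  override def readSchema(): StructType = schema

  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Collections.emptyList[InputPartition[InternalRow]]()
}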
- val scanConfig = r.scanConfigBuilder.build() // ensure there is a projection, which will produce unsafe rows required by some operators ProjectExec(r.output, - DataSourceV2ScanExec( - r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - WriteToDataSourceV2Exec(r.newWriteSupport(), planLater(query)) :: Nil + WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil case Repartition(1, false, child) => - val isContinuous = child.find { - case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport] - case _ => false + val isContinuous = child.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r }.isDefined if (isContinuous) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index e9cc3991155c..5267f5f1580c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,7 +21,6 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} private[sql] object DataSourceV2Utils extends Logging { @@ -56,12 +55,4 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } - - def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { - val name = ds match { - case register: DataSourceRegister => register.shortName() - case _ => ds.getClass.getName - } - throw new UnsupportedOperationException(name + " source does not support user-specified schema") - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index c3f7b690ef63..59ebb9bc5431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -23,11 +23,15 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** @@ -35,8 +39,7 @@ import org.apache.spark.util.Utils * specific logical plans, like 
[[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. */ @deprecated("Use specific logical plans like AppendData instead", "2.4.0") -case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPlan) - extends LogicalPlan { +case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -44,48 +47,46 @@ case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPl /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) - extends SparkPlan { - +case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writeSupport.createBatchWriterFactory() - val useCommitCoordinator = writeSupport.useCommitCoordinator + val writeTask = writer.createWriterFactory() + val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) - logInfo(s"Start processing data source write support: $writeSupport. " + + logInfo(s"Start processing data source writer: $writer. " + s"The input RDD has ${messages.length} partitions.") try { sparkContext.runJob( rdd, (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message - writeSupport.onDataWriterCommit(message) + writer.onDataWriterCommit(message) } ) - logInfo(s"Data source write support $writeSupport is committing.") - writeSupport.commit(messages) - logInfo(s"Data source write support $writeSupport committed.") + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { case cause: Throwable => - logError(s"Data source write support $writeSupport is aborting.") + logError(s"Data source writer $writer is aborting.") try { - writeSupport.abort(messages) + writer.abort(messages) } catch { case t: Throwable => - logError(s"Data source write support $writeSupport failed to abort.") + logError(s"Data source writer $writer failed to abort.") cause.addSuppressed(t) throw new SparkException("Writing job failed.", cause) } - logError(s"Data source write support $writeSupport aborted.") + logError(s"Data source writer $writer aborted.") cause match { // Only wrap non fatal exceptions. 
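On the batch path, the DataSourceWriter driven by WriteToDataSourceV2Exec above comes from a WriteSupport provider, which DataFrameWriter calls earlier in this patch via ws.createWriter(...). A hypothetical provider (LoggingSink, reusing LoggingWriter and BufferingWriterFactory from the earlier sketches):

import java.util.Optional

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, WriteSupport}
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
import org.apache.spark.sql.types.StructType

class LoggingSink extends DataSourceV2 with WriteSupport {

  override def createWriter(
      writeUUID: String,
      schema: StructType,
      mode: SaveMode,
      options: DataSourceOptions): Optional[DataSourceWriter] = {
    // Returning Optional.empty() skips the write entirely (DataFrameWriter checks isPresent).
    Optional.of(new LoggingWriter(new BufferingWriterFactory))
  }
}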
case NonFatal(e) => throw new SparkException("Writing job aborted.", e) @@ -99,7 +100,7 @@ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: Spark object DataWritingSparkTask extends Logging { def run( - writerFactory: DataWriterFactory, + writeTask: DataWriterFactory[InternalRow], context: TaskContext, iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { @@ -108,7 +109,8 @@ object DataWritingSparkTask extends Logging { val partId = context.partitionId() val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() - val dataWriter = writerFactory.createWriter(partId, taskId) + val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") + val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index b1cafd67820c..2fef2db361e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.Optional + import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} @@ -26,9 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -49,8 +51,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readSupportToDataSourceMap = - MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] + private val readerToDataSourceMap = + MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -89,19 +91,20 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if + dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - 
val readSupport = dataSourceV2.createMicroBatchReadSupport( + val reader = dataSourceV2.createMicroBatchReader( + Optional.empty(), // user specified schema metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + + readerToDataSourceMap(reader) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReader [$reader] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(readSupport, output)(sparkSession) + StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { @@ -341,19 +344,19 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: RateControlMicroBatchReadSupport => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("latestOffset") { - val startOffset = availableOffsets - .get(s).map(off => s.deserializeOffset(off.json)) - .getOrElse(s.initialOffset()) - (s, Option(s.latestOffset(startOffset))) - } - case s: MicroBatchReadSupport => + case s: MicroBatchReader => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("latestOffset") { - (s, Option(s.latestOffset())) + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } + + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -393,8 +396,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (readSupport: MicroBatchReadSupport, off) => - readSupport.commit(readSupport.deserializeOffset(off.json)) + case (reader: MicroBatchReader, off) => + reader.commit(reader.deserializeOffset(off.json)) case (src, _) => throw new IllegalArgumentException( s"Unknown source is found at constructNextBatch: $src") @@ -438,34 +441,30 @@ class MicroBatchExecution( s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - - // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but - // to be compatible with streaming source v1, we return a logical plan as a new batch here. 
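The calling convention introduced above - setOffsetRange(start, Optional.empty()) to let the source infer an end offset, getEndOffset() to read it back, and a second setOffsetRange(start, end) before the batch is planned - can be sketched with a hypothetical MicroBatchReader. CounterMicroBatchReader and currentlyAvailableRecords() are illustrative only; RecordNumberOffset and RangeInputPartition come from the earlier sketches.

import java.util.{Collections, Optional, List => JList}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

class CounterMicroBatchReader extends MicroBatchReader {

  private var startOffset: RecordNumberOffset = RecordNumberOffset(0)
  private var endOffset: RecordNumberOffset = RecordNumberOffset(0)

  override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {
    startOffset = start.orElse(RecordNumberOffset(0)).asInstanceOf[RecordNumberOffset]
    // If no end is supplied, pick one: here, whatever is currently available.
    endOffset = end.orElse(RecordNumberOffset(currentlyAvailableRecords()))
      .asInstanceOf[RecordNumberOffset]
  }

  override def getStartOffset(): Offset = startOffset
  override def getEndOffset(): Offset = endOffset
  override def deserializeOffset(json: String): Offset = RecordNumberOffset(json.toLong)
  override def commit(end: Offset): Unit = ()   // nothing to clean up in this toy source
  override def stop(): Unit = ()

  override def readSchema(): StructType = StructType(Seq(StructField("value", LongType)))

  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Collections.singletonList[InputPartition[InternalRow]](
      new RangeInputPartition(startOffset.recordNumber, endOffset.recordNumber))

  // Stand-in for a real lookup of how much data the external system currently has.
  private def currentlyAvailableRecords(): Long = 100L
}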
- case (readSupport: MicroBatchReadSupport, available) - if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(readSupport).map { - off => readSupport.deserializeOffset(off.json) - } - val endOffset: OffsetV2 = available match { - case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) + case (reader: MicroBatchReader, available) + if committedOffsets.get(reader).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + val availableV2: OffsetV2 = available match { + case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - val startOffset = current.getOrElse(readSupport.initialOffset) - val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) - logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") + reader.setOffsetRange( + toJava(current), + Optional.of(availableV2)) + logDebug(s"Retrieving data from $reader: $current -> $availableV2") - val (source, options) = readSupport match { + val (source, options) = reader match { // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` // implementation. We provide a fake one here for explain. case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] // Provide a fake value here just in case something went wrong, e.g. the reader gives // a wrong `equals` implementation. - case _ => readSupportToDataSourceMap.getOrElse(readSupport, { + case _ => readerToDataSourceMap.getOrElse(reader, { FakeDataSourceV2 -> Map.empty[String, String] }) } - Some(readSupport -> StreamingDataSourceV2Relation( - readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) + Some(reader -> StreamingDataSourceV2Relation( + reader.readSchema().toAttributes, source, options, reader)) case _ => None } } @@ -499,13 +498,13 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamingWriteSupportProvider => - val writer = s.createStreamingWriteSupport( + case s: StreamWriteSupport => + val writer = s.createStreamWriter( s"$runId", newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan) + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -531,7 +530,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamingWriteSupportProvider => + case _: StreamWriteSupport => // This doesn't accumulate any data - it just forces execution of the microbatch writer. 
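The sink side of the micro-batch path above starts from a StreamWriteSupport provider: createStreamWriter() is called once per query, and the resulting StreamWriter is then wrapped per batch. A hypothetical provider (LoggingStreamSink, reusing IdempotentStreamWriter and BufferingWriterFactory from the earlier sketches):

import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

class LoggingStreamSink extends DataSourceV2 with StreamWriteSupport {

  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter =
    new IdempotentStreamWriter(new BufferingWriterFactory)
}

Such a sink is selected with writeStream.format(...) and is handled by the StreamWriteSupport branches shown in the MicroBatchExecution changes above.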
nextBatch.collect() } @@ -555,6 +554,10 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } + + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { + Optional.ofNullable(scalaOption.orNull) + } } object MicroBatchExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d4b50655c721..6a380ab89ff7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -251,7 +251,7 @@ trait ProgressReporter extends Logging { // Check whether the streaming query's logical plan has only V2 data sources val allStreamingLeaves = logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } } if (onlyDataSourceV2Sources) { @@ -278,7 +278,7 @@ trait ProgressReporter extends Logging { new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() lastExecution.executedPlan.collectLeaves().foreach { - case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => uniqueStreamingExecLeavesMap.put(s, s) case _ => } @@ -286,7 +286,7 @@ trait ProgressReporter extends Logging { val sourceToInputRowsTuples = uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource] + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] source -> numRows }.toSeq logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala deleted file mode 100644 index 1be071614d92..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} -import org.apache.spark.sql.types.StructType - -/** - * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to - * carry schema and offsets for streaming data sources. - */ -class SimpleStreamingScanConfigBuilder( - schema: StructType, - start: Offset, - end: Option[Offset] = None) - extends ScanConfigBuilder { - - override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) -} - -case class SimpleStreamingScanConfig( - readSchema: StructType, - start: Offset, - end: Option[Offset]) - extends ScanConfig diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 4b696dfa5735..24195b5657e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -83,7 +83,7 @@ case class StreamingExecutionRelation( // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. We don't -// know at read time whether the query is continuous or not, so we need to be able to +// know at read time whether the query is conntinuous or not, so we need to be able to // swap a V1 relation back in. /** * Used to link a [[DataSourceV2]] into a streaming @@ -113,7 +113,7 @@ case class StreamingRelationV2( * Used to link a [[DataSourceV2]] into a continuous processing execution. 
*/ case class ContinuousExecutionRelation( - source: ContinuousReadSupportProvider, + source: ContinuousReadSupport, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9c5c16f4f5d1..cfba1001c6de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with StreamingWriteSupportProvider + with StreamWriteSupport with DataSourceRegister with CreatableRelationProvider { - override def createStreamingWriteSupport( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new ConsoleWriteSupport(schema, options) + options: DataSourceOptions): StreamWriter = { + new ConsoleWriter(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index b68f67e0b22d..554a0b0573f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,13 +21,12 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition) + val inputPartition: InputPartition[InternalRow]) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -50,22 +49,15 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val inputPartitions: Seq[InputPartition], - schema: StructType, - partitionReaderFactory: ContinuousPartitionReaderFactory) + private val readerInputPartitions: Seq[InputPartition[InternalRow]]) extends 
RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - inputPartitions.zipWithIndex.map { + readerInputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } - private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match { - case p: ContinuousDataSourceRDDPartition => p - case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split") - } - /** * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. @@ -77,12 +69,10 @@ class ContinuousDataSourceRDD( } val readerForPartition = { - val partition = castPartition(split) + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] if (partition.queueReader == null) { - val partitionReader = partitionReaderFactory.createReader( - partition.inputPartition) - partition.queueReader = new ContinuousQueuedDataReader( - partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs) + partition.queueReader = + new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -103,6 +93,17 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - castPartition(split).inputPartition.preferredLocations() + split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { + reader match { + case r: ContinuousInputPartitionReader[InternalRow] => r + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 4ddebb33b79d..140cec64fffb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,12 +29,13 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} +import org.apache.spark.sql.types.StructType import 
org.apache.spark.util.{Clock, Utils} class ContinuousExecution( @@ -42,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamingWriteSupportProvider, + sink: StreamWriteSupport, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -52,7 +53,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. @@ -62,8 +63,7 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => - // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? + source: ContinuousReadSupport, _, extraReaderOptions, output, _) => toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) @@ -148,7 +148,8 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - dataSource.createContinuousReadSupport( + dataSource.createContinuousReader( + java.util.Optional.empty[StructType](), metadataPath, new DataSourceOptions(extraReaderOptions.asJava)) } @@ -159,9 +160,9 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { case ContinuousExecutionRelation(source, options, output) => - val readSupport = continuousSources(insertedSourceId) + val reader = continuousSources(insertedSourceId) insertedSourceId += 1 - val newOutput = readSupport.fullSchema().toAttributes + val newOutput = reader.readSchema().toAttributes assert(output.size == newOutput.size, s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + @@ -169,10 +170,9 @@ class ContinuousExecution( replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) - val startOffset = realOffset.getOrElse(readSupport.initialOffset) - val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) - StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) + val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) + reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) + StreamingDataSourceV2Relation(newOutput, source, options, reader) } // Rewire the plan to use the new attributes that were returned by the source. 
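By way of illustration, a minimal sketch of the start-offset restoration performed in the hunk above, assuming only the ContinuousReader methods this patch itself exercises (deserializeOffset, setStartOffset, getStartOffset); the object name, restoreStart, and loggedJson are hypothetical.

    import java.util.Optional

    import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, Offset}

    object ContinuousStartOffsetSketch {
      // Restore a continuous reader's start position from an optionally logged offset;
      // an empty Optional tells the reader to fall back to its own initial offset.
      def restoreStart(reader: ContinuousReader, loggedJson: Option[String]): Offset = {
        val restored = loggedJson.map(json => reader.deserializeOffset(json))
        reader.setStartOffset(Optional.ofNullable(restored.orNull))
        reader.getStartOffset
      }
    }
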
@@ -185,13 +185,17 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamingWriteSupport( + val writer = sink.createStreamWriter( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) + val reader = withSink.collect { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + }.head + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, @@ -204,11 +208,6 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - val (readSupport, scanConfig) = lastExecution.executedPlan.collect { - case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => - scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig - }.head - sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across @@ -224,16 +223,14 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && - state.compareAndSet(ACTIVE, RECONFIGURING) - if (shouldReconfigure) { + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -283,12 +280,10 @@ class ContinuousExecution( * Report ending partition offsets for the given reader at the given epoch. 
*/ def addOffset( - epoch: Long, - readSupport: ContinuousReadSupport, - partitionOffsets: Seq[PartitionOffset]): Unit = { + epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 65c5fc63c2f4..ec1dabd7da3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -25,9 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset import org.apache.spark.util.ThreadUtils /** @@ -38,14 +37,15 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partitionIndex: Int, - reader: ContinuousPartitionReader[InternalRow], - schema: StructType, + partition: ContinuousDataSourceRDDPartition, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { + private val reader = partition.inputPartition.createPartitionReader() + // Important sequencing - we must get our starting point before the provider threads start running - private var currentOffset: PartitionOffset = reader.getOffset + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset /** * The record types in the read buffer. @@ -66,7 +66,7 @@ class ContinuousQueuedDataReader( epochMarkerExecutor.scheduleWithFixedDelay( epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - private val dataReaderThread = new DataReaderThread(schema) + private val dataReaderThread = new DataReaderThread dataReaderThread.setDaemon(true) dataReaderThread.start() @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset)) + partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -128,16 +128,16 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when - * a new row arrives to the [[ContinuousPartitionReader]]. + * a new row arrives to the [[InputPartitionReader]]. 
*/ - class DataReaderThread(schema: StructType) extends Thread( + class DataReaderThread extends Thread( s"continuous-reader--${context.partitionId()}--" + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { @volatile private[continuous] var failureReason: Throwable = _ - private val toUnsafe = UnsafeProjection.create(schema) override def run(): Unit = { TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) try { while (!shouldStop()) { if (!reader.next()) { @@ -149,9 +149,8 @@ class ContinuousQueuedDataReader( return } } - // `InternalRow#copy` may not be properly implemented, for safety we convert to unsafe row - // before copy here. - queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset)) + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) } } catch { case _: InterruptedException => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index a6cde2b8a710..551e07c3db86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -17,22 +17,24 @@ package org.apache.spark.sql.execution.streaming.continuous +import scala.collection.JavaConverters._ + import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { +class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -54,18 +56,18 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def fullSchema(): StructType = RateStreamProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } + private var offset: Offset = _ - override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + } - override def 
planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + override def getStartOffset(): Offset = offset - val partitionStartMap = startOffset match { + override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { + val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -88,12 +90,8 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin i, numPartitions, perPartitionRate) - }.toArray - } - - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - RateStreamContinuousReaderFactory + .asInstanceOf[InputPartition[InternalRow]] + }.asJava } override def commit(end: Offset): Unit = {} @@ -120,23 +118,33 @@ case class RateStreamContinuousInputPartition( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends InputPartition - -object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory { - override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { - val p = partition.asInstanceOf[RateStreamContinuousInputPartition] - new RateStreamContinuousPartitionReader( - p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond) + extends ContinuousInputPartition[InternalRow] { + + override def createContinuousReader( + offset: PartitionOffset): InputPartitionReader[InternalRow] = { + val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] + require(rateStreamOffset.partition == partitionIndex, + s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") + new RateStreamContinuousInputPartitionReader( + rateStreamOffset.currentValue, + rateStreamOffset.currentTimeMs, + partitionIndex, + increment, + rowsPerSecond) } + + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new RateStreamContinuousInputPartitionReader( + startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamContinuousPartitionReader( +class RateStreamContinuousInputPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousPartitionReader[InternalRow] { + extends ContinuousInputPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index 28ab2448a663..56bfefd91aaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp -import java.util.Calendar +import java.util.{Calendar, List => JList} import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.{DefaultFormats, NoTypeHints} @@ -33,26 +34,24 @@ import org.apache.spark.internal.Logging 
import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord} import org.apache.spark.sql.execution.streaming.sources.TextSocketReader import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** - * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials - * and debugging. This ContinuousReadSupport will *not* work in production applications due to - * multiple reasons, including no support for fault recovery. + * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This ContinuousReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. * * The driver maintains a socket connection to the host-port, keeps the received messages in * buckets and serves the messages to the executors via a RPC endpoint. */ -class TextSocketContinuousReadSupport(options: DataSourceOptions) - extends ContinuousReadSupport with Logging { - +class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader with Logging { implicit val defaultFormats: DefaultFormats = DefaultFormats private val host: String = options.get("host").get() @@ -74,8 +73,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) @GuardedBy("this") private var currentOffset: Int = -1 - // Exposed for tests. 
- private[spark] var startOffset: TextSocketOffset = _ + private var startOffset: TextSocketOffset = _ private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -96,16 +94,16 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) TextSocketOffset(Serialization.read[List[Int]](json)) } - override def initialOffset(): Offset = { - startOffset = TextSocketOffset(List.fill(numPartitions)(0)) - startOffset + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { + this.startOffset = offset + .orElse(TextSocketOffset(List.fill(numPartitions)(0))) + .asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) } - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } + override def getStartOffset: Offset = startOffset - override def fullSchema(): StructType = { + override def readSchema(): StructType = { if (includeTimestamp) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -113,10 +111,8 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) } } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] - .start.asInstanceOf[TextSocketOffset] - recordEndpoint.setStartOffsets(startOffset.offsets) + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) @@ -136,13 +132,10 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) startOffset.offsets.zipWithIndex.map { case (offset, i) => - TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp) - }.toArray - } + TextSocketContinuousInputPartition( + endpointName, i, offset, includeTimestamp): InputPartition[InternalRow] + }.asJava - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - TextSocketReaderFactory } override def commit(end: Offset): Unit = synchronized { @@ -197,7 +190,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) logWarning(s"Stream closed by $host:$port") return } - TextSocketContinuousReadSupport.this.synchronized { + TextSocketContinuousReader.this.synchronized { currentOffset += 1 val newData = (line, Timestamp.valueOf( @@ -228,30 +221,25 @@ case class TextSocketContinuousInputPartition( driverEndpointName: String, partitionId: Int, startOffset: Int, - includeTimestamp: Boolean) extends InputPartition - - -object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { + includeTimestamp: Boolean) +extends InputPartition[InternalRow] { - override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { - val p = partition.asInstanceOf[TextSocketContinuousInputPartition] - new TextSocketContinuousPartitionReader( - p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp) - } + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset, + includeTimestamp) } - /** * Continuous text socket input partition reader. * * Polls the driver endpoint for new records. 
*/ -class TextSocketContinuousPartitionReader( +class TextSocketContinuousInputPartitionReader( driverEndpointName: String, partitionId: Int, startOffset: Int, includeTimestamp: Boolean) - extends ContinuousPartitionReader[InternalRow] { + extends ContinuousInputPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index a08411d746ab..967dbe24a370 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.DataWriter -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} import org.apache.spark.util.Utils /** @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils * * We keep repeating prev.compute() and writing new epochs until the query is shut down. */ -class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) extends RDD[Unit](prev) { override val partitioner = prev.partitioner @@ -51,7 +50,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { val dataIterator = prev.compute(split, context) - dataWriter = writerFactory.createWriter( + dataWriter = writeTask.createDataWriter( context.partitionId(), context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 2238ce26e7b4..8877ebeb2673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. 
*/ def create( - writeSupport: StreamingWriteSupport, - readSupport: ContinuousReadSupport, + writer: StreamWriter, + reader: ContinuousReader, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) + writer, reader, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writeSupport: StreamingWriteSupport, - readSupport: ContinuousReadSupport, + writer: StreamWriter, + reader: ContinuousReader, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -198,7 +198,7 @@ private[continuous] class EpochCoordinator( s"and is ready to be committed. Committing epoch $epoch.") // Sequencing is important here. We must commit to the writer before recording the commit // in the query, or we will end up dropping the commit if we restart in the middle. - writeSupport.commit(epoch, messages.toArray) + writer.commit(epoch, messages.toArray) query.commit(epoch) } @@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) + query.addOffset(epoch, reader, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 7ad21cc304e7..943c731a7052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** * The logical plan for writing data in a continuous stream. 
*/ case class WriteToContinuousDataSource( - writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index c216b6138385..927d3a84e296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,21 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** - * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. + * The physical plan for writing data into a continuous processing [[StreamWriter]]. */ -case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) extends SparkPlan with Logging { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writeSupport.createStreamingWriterFactory() + val writerFactory = writer.createWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source write support: $writeSupport. " + + logInfo(s"Start processing data source writer: $writer. 
" + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index adf52aba21a0..f81abdcc3711 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -31,8 +34,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -64,7 +67,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def fullSchema(): StructType = encoder.schema + def readSchema(): StructType = encoder.schema protected def logicalPlan: LogicalPlan @@ -77,7 +80,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. 
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -119,22 +122,24 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] + endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] + } + } - override def initialOffset: OffsetV2 = LongOffset(-1) + override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - override def latestOffset(): OffsetV2 = { - if (currentOffset.offset == -1) null else currentOffset + override def getStartOffset: OffsetV2 = synchronized { + if (startOffset.offset == -1) null else startOffset } - override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + override def getEndOffset: OffsetV2 = synchronized { + if (endOffset.offset == -1) null else endOffset } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOffset = sc.start.asInstanceOf[LongOffset] - val endOffset = sc.end.get.asInstanceOf[LongOffset] + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -151,15 +156,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block) - }.toArray + new MemoryStreamInputPartition(block): InputPartition[InternalRow] + }.asJava } } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - MemoryStreamReaderFactory - } - private def generateDebugString( rows: Seq[UnsafeRow], startOrdinal: Int, @@ -200,12 +201,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition - -object MemoryStreamReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val records = partition.asInstanceOf[MemoryStreamInputPartition].records - new PartitionReader[InternalRow] { +class MemoryStreamInputPartition(records: Array[UnsafeRow]) + extends InputPartition[InternalRow] { + override def createPartitionReader(): InputPartitionReader[InternalRow] = { + new InputPartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala rename to 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 833e62f35ede..fd45ba509091 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -19,15 +19,16 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) - extends StreamingWriteSupport with Logging { +class ConsoleWriter(schema: StructType, options: DataSourceOptions) + extends StreamWriter with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) @@ -38,7 +39,7 @@ class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) assert(SparkSession.getActiveSession.isDefined) protected val spark = SparkSession.getActiveSession.get - def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory + def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index dbcc4483e577..4a32217f149b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -17,22 +17,26 @@ package org.apache.spark.sql.execution.streaming.sources +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.{Offset => _, _} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.execution.streaming._ +import 
org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** @@ -44,9 +48,7 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) - with ContinuousReadSupportProvider with ContinuousReadSupport { - + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -57,6 +59,9 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa @GuardedBy("this") private val records = Seq.fill(numPartitions)(new ListBuffer[A]) + @GuardedBy("this") + private var startOffset: ContinuousMemoryStreamOffset = _ + private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -70,8 +75,15 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } - override def initialOffset(): Offset = { - ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. + startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = synchronized { + startOffset } override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { @@ -86,40 +98,34 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] - .start.asInstanceOf[ContinuousMemoryStreamOffset] + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) startOffset.partitionNums.map { - case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) - }.toArray + case (part, index) => + new ContinuousMemoryStreamInputPartition( + endpointName, part, index): InputPartition[InternalRow] + }.toList.asJava } } - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - ContinuousMemoryStreamReaderFactory - } - override def stop(): Unit = { if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) } override def commit(end: Offset): Unit = {} - // ContinuousReadSupportProvider implementation + // ContinuousReadSupport implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. 
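For context, a minimal sketch of how an engine obtains a ContinuousReader from a ContinuousReadSupport provider such as ContinuousMemoryStream (which, per the hunk below, simply returns itself); the object name, readerFor, and checkpoint are hypothetical, and the options map is left empty.

    import java.util.{Collections, Optional}

    import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
    import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
    import org.apache.spark.sql.types.StructType

    object ContinuousReaderLookupSketch {
      // Ask the provider for a reader bound to the given checkpoint location.
      def readerFor(provider: ContinuousReadSupport, checkpoint: String): ContinuousReader = {
        provider.createContinuousReader(
          Optional.empty[StructType](),
          checkpoint,
          new DataSourceOptions(Collections.emptyMap[String, String]()))
      }
    }
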
- override def createContinuousReadSupport( + def createContinuousReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = this + options: DataSourceOptions): ContinuousReader = { + this + } } object ContinuousMemoryStream { @@ -135,16 +141,12 @@ object ContinuousMemoryStream { /** * An input partition for continuous memory stream. */ -case class ContinuousMemoryStreamInputPartition( +class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends InputPartition - -object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory { - override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { - val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition] - new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset) - } + startOffset: Int) extends InputPartition[InternalRow] { + override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = + new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) } /** @@ -152,10 +154,10 @@ object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFact * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamPartitionReader( +class ContinuousMemoryStreamInputPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousPartitionReader[InternalRow] { + startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala similarity index 82% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index 4218fd51ad20..e8ce21cc1204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -22,9 +22,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,21 +37,20 @@ import org.apache.spark.sql.types.StructType * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. 
*/ -case class ForeachWriteSupportProvider[T]( +case class ForeachWriterProvider[T]( writer: ForeachWriter[T], - converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends StreamingWriteSupportProvider { + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { - override def createStreamingWriteSupport( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new StreamingWriteSupport { + options: DataSourceOptions): StreamWriter = { + new StreamWriter { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { val rowConverter: InternalRow => T = converter match { case Left(enc) => val boundEnc = enc.resolveAndBind( @@ -69,16 +68,16 @@ case class ForeachWriteSupportProvider[T]( } } -object ForeachWriteSupportProvider { +object ForeachWriterProvider { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriteSupportProvider[UnsafeRow]( + new ForeachWriterProvider[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriteSupportProvider[T](writer, Left(encoder)) + new ForeachWriterProvider[T](writer, Left(encoder)) } } } @@ -86,8 +85,8 @@ object ForeachWriteSupportProvider { case class ForeachWriterFactory[T]( writer: ForeachWriter[T], rowConverter: InternalRow => T) - extends StreamingDataWriterFactory { - override def createWriter( + extends DataWriterFactory[InternalRow] { + override def createDataWriter( partitionId: Int, taskId: Long, epochId: Long): ForeachDataWriter[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala deleted file mode 100644 index 9f88416871f8..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} - -/** - * A [[BatchWriteSupport]] used to hook V2 stream writers into a microbatch plan. It implements - * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped - * streaming write support. - */ -class MicroBatchWritSupport(eppchId: Long, val writeSupport: StreamingWriteSupport) - extends BatchWriteSupport { - - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writeSupport.commit(eppchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = { - writeSupport.abort(eppchId, messages) - } - - override def createBatchWriterFactory(): DataWriterFactory = { - new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory()) - } -} - -class MicroBatchWriterFactory(epochId: Long, streamingWriterFactory: StreamingDataWriterFactory) - extends DataWriterFactory { - - override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { - streamingWriterFactory.createWriter(partitionId, taskId, epochId) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala similarity index 50% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index 90680ea38fbd..2d43a7bb7787 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -17,15 +17,21 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter -// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. -trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { - - override def latestOffset(): Offset = { - throw new IllegalAccessException( - "latestOffset should not be called for RateControlMicroBatchReadSupport") +/** + * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped + * streaming writer. 
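Because `MicroBatchWriter` is only a thin adapter, a sketch of how one micro-batch might flow through it may help. The `Noop*` names, the explicit driver/executor calls, and the `MicroBatchWriterSketch` object are hypothetical; only `MicroBatchWriter`, `StreamWriter`, `DataWriterFactory`, `DataWriter`, and `WriterCommitMessage` come from the interfaces shown in this patch:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

// A writer that discards everything; only the shape of the API matters here.
case object NoopCommitMessage extends WriterCommitMessage

class NoopStreamWriter extends StreamWriter {
  override def createWriterFactory(): DataWriterFactory[InternalRow] = NoopWriterFactory
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}

case object NoopWriterFactory extends DataWriterFactory[InternalRow] {
  // One DataWriter per partition task; the epoch id is threaded through by the engine.
  override def createDataWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new DataWriter[InternalRow] {
    override def write(record: InternalRow): Unit = {}             // drop the row
    override def commit(): WriterCommitMessage = NoopCommitMessage
    override def abort(): Unit = {}
  }
}

object MicroBatchWriterSketch {
  def main(args: Array[String]): Unit = {
    // Driver side: pin batch 5 onto the streaming writer via the adapter.
    val microBatch = new MicroBatchWriter(batchId = 5L, writer = new NoopStreamWriter)
    val factory = microBatch.createWriterFactory()

    // Executor side (normally one call per partition task).
    val dataWriter = factory.createDataWriter(partitionId = 0, taskId = 0L, epochId = 5L)
    dataWriter.write(InternalRow(1, -1))
    val message = dataWriter.commit()

    // Back on the driver: delegates to NoopStreamWriter.commit(5L, Array(message)).
    microBatch.commit(Array(message))
  }
}
```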
+ */ +class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) } - def latestOffset(start: Offset): Offset + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index ac3c71cc222b..f26e11d842b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,18 +21,17 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[BatchWriteSupport]] on the driver. + * to a [[DataSourceWriter]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends StreamingDataWriterFactory { - override def createWriter( +case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] { + override def createDataWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala similarity index 78% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index f5364047adff..9e0d95493216 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -19,24 +19,27 @@ package org.apache.spark.sql.execution.streaming.sources import java.io._ import java.nio.charset.StandardCharsets +import java.util.Optional import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ + import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import 
org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReadSupport with Logging { +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -103,30 +106,38 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca @volatile private var lastTimeMs: Long = creationTimeMs - override def initialOffset(): Offset = LongOffset(0L) + private var start: LongOffset = _ + private var end: LongOffset = _ - override def latestOffset(): Offset = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - } + override def readSchema(): StructType = SCHEMA - override def deserializeOffset(json: String): Offset = { - LongOffset(json.toLong) + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] } - override def fullSchema(): StructType = SCHEMA + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startSeconds = sc.start.asInstanceOf[LongOffset].offset - val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset + override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") if (endSeconds > maxSeconds) { throw new ArithmeticException("Integer overflow. 
Max offset with " + @@ -142,7 +153,7 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") if (rangeStart == rangeEnd) { - return Array.empty + return List.empty.asJava } val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) @@ -159,11 +170,8 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca (0 until numPartitions).map { p => new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - }.toArray - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - RateStreamMicroBatchReaderFactory + : InputPartition[InternalRow] + }.toList.asJava } override def commit(end: Offset): Unit = {} @@ -175,29 +183,26 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -case class RateStreamMicroBatchInputPartition( +class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartition + relativeMsPerValue: Double) extends InputPartition[InternalRow] { -object RateStreamMicroBatchReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val p = partition.asInstanceOf[RateStreamMicroBatchInputPartition] - new RateStreamMicroBatchPartitionReader(p.partitionId, p.numPartitions, p.rangeStart, - p.rangeEnd, p.localStartTimeMs, p.relativeMsPerValue) - } + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, + localStartTimeMs, relativeMsPerValue) } -class RateStreamMicroBatchPartitionReader( +class RateStreamMicroBatchInputPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends PartitionReader[InternalRow] { + relativeMsPerValue: Double) extends InputPartitionReader[InternalRow] { private var count: Long = 0 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6942dfbfe0ec..6bdd492f0cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util.Optional + import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} import org.apache.spark.sql.types._ /** @@ -39,12 +42,13 @@ import org.apache.spark.sql.types._ * be resource constrained, and 
`numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReadSupport( + override def createMicroBatchReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { + options: DataSourceOptions): MicroBatchReader = { if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { @@ -70,14 +74,17 @@ class RateStreamProvider extends DataSourceV2 } } - new RateStreamMicroBatchReadSupport(options, checkpointLocation) + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) } - override def createContinuousReadSupport( + override def createContinuousReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { - new RateStreamContinuousReadSupport(options) - } + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index c50dc7bcb8da..cb76e8650339 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -42,15 +42,13 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
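The rate provider above and the socket provider later in this patch follow the same reverted split: a `DataSourceV2` mixes in `MicroBatchReadSupport` and hands back a `MicroBatchReader`, which owns offset tracking and partition planning. To keep the moving parts in one place, here is a condensed, hypothetical "counter" source; every name in it is invented for illustration, and it assumes only the interfaces that appear in this diff (`MicroBatchReadSupport`, `MicroBatchReader`, `InputPartition`, `InputPartitionReader`, `LongOffset`):

```scala
import java.util.{List => JList, Optional}

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.LongOffset
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Hypothetical provider, mirroring the shape of the rate and socket providers in this patch.
class CounterSourceProvider extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister {
  override def createMicroBatchReader(
      schema: Optional[StructType],
      checkpointLocation: String,
      options: DataSourceOptions): MicroBatchReader = {
    // Sources that compute their own schema typically reject a user-specified one.
    require(!schema.isPresent, "the counter source computes its own schema")
    new CounterMicroBatchReader
  }
  override def shortName(): String = "counter"
}

// Serves the longs in (start, end] of each micro-batch, one row per value.
class CounterMicroBatchReader extends MicroBatchReader {
  private val schema = StructType(StructField("value", LongType) :: Nil)
  private var start: LongOffset = _
  private var end: LongOffset = _

  // The engine sets the batch's offset range before asking for partitions.
  override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {
    this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset]
    this.end = end.orElse(LongOffset(this.start.offset + 5L)).asInstanceOf[LongOffset]
  }
  override def getStartOffset(): Offset = start
  override def getEndOffset(): Offset = end
  override def deserializeOffset(json: String): Offset = LongOffset(json.toLong)
  override def readSchema(): StructType = schema

  // One partition covering (start, end]; the partition creates its own reader on the executor.
  override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
    val (lo, hi) = (start.offset, end.offset)
    Seq[InputPartition[InternalRow]](new InputPartition[InternalRow] {
      override def createPartitionReader(): InputPartitionReader[InternalRow] =
        new InputPartitionReader[InternalRow] {
          private var current = lo
          override def next(): Boolean = { current += 1; current <= hi }
          override def get(): InternalRow = InternalRow(current)
          override def close(): Unit = {}
        }
    }).asJava
  }

  override def commit(end: Offset): Unit = {} // nothing to clean up in this sketch
  override def stop(): Unit = {}
}
```

For each batch the engine sets the offset range first and only then plans partitions; each `InputPartition` is shipped to an executor, where `createPartitionReader` produces the actual row iterator.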
*/ -class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider - with MemorySinkBase with Logging { - - override def createStreamingWriteSupport( +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new MemoryStreamingWriteSupport(this, mode, schema) + options: DataSourceOptions): StreamWriter = { + new MemoryStreamWriter(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -122,13 +120,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryStreamingWriteSupport( - val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamingWriteSupport { +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + extends StreamWriter { - override def createStreamingWriterFactory: MemoryWriterFactory = { - MemoryWriterFactory(outputMode, schema) - } + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -143,19 +138,13 @@ class MemoryStreamingWriteSupport( } case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) - extends DataWriterFactory with StreamingDataWriterFactory { + extends DataWriterFactory[InternalRow] { - override def createWriter( - partitionId: Int, - taskId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode, schema) - } - - override def createWriter( + override def createDataWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { - createWriter(partitionId, taskId) + new MemoryDataWriter(partitionId, outputMode, schema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index b2a573eae504..874c479db95d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, List => JList, Locale, Optional} import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} @@ -31,15 +32,16 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} -import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader import 
org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String +// Shared object for micro-batch and continuous reader object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: @@ -48,12 +50,14 @@ object TextSocketReader { } /** - * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials - * and debugging. This MicroBatchReadSupport will *not* work in production applications due to - * multiple reasons, including no support for fault recovery. + * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This MicroBatchReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReadSupport(options: DataSourceOptions) - extends MicroBatchReadSupport with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { + + private var startOffset: Offset = _ + private var endOffset: Offset = _ private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -99,7 +103,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReadSupport.this.synchronized { + TextSocketMicroBatchReader.this.synchronized { val newData = ( UTF8String.fromString(line), DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) @@ -116,15 +120,24 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) readThread.start() } - override def initialOffset(): Offset = LongOffset(-1L) + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { + startOffset = start.orElse(LongOffset(-1L)) + endOffset = end.orElse(currentOffset) + } - override def latestOffset(): Offset = currentOffset + override def getStartOffset(): Offset = { + Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) + } override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def fullSchema(): StructType = { + override def readSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -132,14 +145,12 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) } } - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } + override def 
planInputPartitions(): JList[InputPartition[InternalRow]] = { + assert(startOffset != null && endOffset != null, + "start offset and end offset should already be set before create read tasks.") - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 - val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 + val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 + val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -161,29 +172,26 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) slices(idx % numPartitions).append(r) } - slices.map(TextSocketInputPartition) - } + (0 until numPartitions).map { i => + val slice = slices(i) + new InputPartition[InternalRow] { + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new InputPartitionReader[InternalRow] { + private var currentIdx = -1 - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - new PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val slice = partition.asInstanceOf[TextSocketInputPartition].slice - new PartitionReader[InternalRow] { - private var currentIdx = -1 + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) + } - override def get(): InternalRow = { - InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) + override def close(): Unit = {} } - - override def close(): Unit = {} - } } - } + }.toList.asJava } override def commit(end: Offset): Unit = synchronized { @@ -219,11 +227,8 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) override def toString: String = s"TextSocketV2[host: $host, port: $port]" } -case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition - class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider - with DataSourceRegister with Logging { + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! 
" + @@ -243,18 +248,27 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def createMicroBatchReadSupport( + override def createMicroBatchReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { + options: DataSourceOptions): MicroBatchReader = { checkParameters(options) - new TextSocketMicroBatchReadSupport(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + new TextSocketMicroBatchReader(options) } - override def createContinuousReadSupport( + override def createContinuousReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { + options: DataSourceOptions): ContinuousReader = { checkParameters(options) - new TextSocketContinuousReadSupport(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + new TextSocketContinuousReader(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 39e9e1ad426b..ef8dc3a325a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.Locale +import java.util.{Locale, Optional} import scala.collection.JavaConverters._ @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,21 +172,19 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupportProvider => - var tempReadSupport: MicroBatchReadSupport = null + case s: MicroBatchReadSupport => + var tempReader: MicroBatchReader = null val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) - } else { - s.createMicroBatchReadSupport(tmpCheckpointPath, options) - } - tempReadSupport.fullSchema() + tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + tempReader.readSchema() } finally { // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null + if (tempReader != null) { + tempReader.stop() + tempReader = null } } Dataset.ofRows( @@ -194,28 
+192,16 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo StreamingRelationV2( s, source, extraOptions.toMap, schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupportProvider => - var tempReadSupport: ContinuousReadSupport = null - val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) - } else { - s.createContinuousReadSupport(tmpCheckpointPath, options) - } - tempReadSupport.fullSchema() - } finally { - // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null - } - } + case s: ContinuousReadSupport => + val tempReader = s.createContinuousReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - schema.toAttributes, v1Relation)(sparkSession)) + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 7866e4f70f14..3b9a56ffdde4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.{InterfaceStability, Since} import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -270,7 +270,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -299,8 +299,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") val sink = ds.newInstance() match { - case w: StreamingWriteSupportProvider - if !disabledSources.contains(w.getClass.getCanonicalName) => w + case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index cd52d991d55c..25bb05212d66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -256,7 +256,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => + case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 5602310219a7..e4cead9df429 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,71 +24,29 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { +public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - public class ReadSupport extends JavaSimpleReadSupport { - @Override - public ScanConfigBuilder newScanConfigBuilder() { - return new AdvancedScanConfigBuilder(); - 
} - - @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; - List res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaRangeInputPartition(0, 5)); - res.add(new JavaRangeInputPartition(5, 10)); - } else if (lowerBound < 4) { - res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); - res.add(new JavaRangeInputPartition(5, 10)); - } else if (lowerBound < 9) { - res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); - } - - return res.stream().toArray(InputPartition[]::new); - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; - return new AdvancedReaderFactory(requiredSchema); - } - } - - public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, - SupportsPushDownFilters, SupportsPushDownRequiredColumns { + public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, + SupportsPushDownFilters { // Exposed for testing. public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); public Filter[] filters = new Filter[0]; @Override - public void pruneColumns(StructType requiredSchema) { - this.requiredSchema = requiredSchema; + public StructType readSchema() { + return requiredSchema; } @Override - public StructType readSchema() { - return requiredSchema; + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; } @Override @@ -121,54 +79,79 @@ public Filter[] pushedFilters() { } @Override - public ScanConfig build() { - return this; + public List> planInputPartitions() { + List> res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); + } else if (lowerBound < 4) { + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); + } else if (lowerBound < 9) { + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); + } + + return res; } } - static class AdvancedReaderFactory implements PartitionReaderFactory { - StructType requiredSchema; + static class JavaAdvancedInputPartition implements InputPartition, + InputPartitionReader { + private int start; + private int end; + private StructType requiredSchema; - AdvancedReaderFactory(StructType requiredSchema) { + JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { + this.start = start; + this.end = end; this.requiredSchema = requiredSchema; } @Override - public PartitionReader createReader(InputPartition partition) { - JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - return new PartitionReader() { - private int current = p.start - 1; - - @Override - public boolean next() throws IOException { - current += 1; - 
return current < p.end; - } + public InputPartitionReader createPartitionReader() { + return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); + } - @Override - public InternalRow get() { - Object[] values = new Object[requiredSchema.size()]; - for (int i = 0; i < values.length; i++) { - if ("i".equals(requiredSchema.apply(i).name())) { - values[i] = current; - } else if ("j".equals(requiredSchema.apply(i).name())) { - values[i] = -current; - } - } - return new GenericInternalRow(values); + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = start; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -start; } + } + return new GenericInternalRow(values); + } - @Override - public void close() throws IOException { + @Override + public void close() throws IOException { - } - }; } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 000000000000..97d6176d0255 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceReader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> planBatchInputPartitions() { + return java.util.Arrays.asList( + new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); + } + } + + static class JavaBatchInputPartition + implements InputPartition, InputPartitionReader { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchInputPartition(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public InputPartitionReader createPartitionReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(vectors); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java deleted file mode 100644 index 28a933039831..000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - - class ReadSupport extends JavaSimpleReadSupport { - - @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - InputPartition[] partitions = new InputPartition[2]; - partitions[0] = new JavaRangeInputPartition(0, 50); - partitions[1] = new JavaRangeInputPartition(50, 90); - return partitions; - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new ColumnarReaderFactory(); - } - } - - static class ColumnarReaderFactory implements PartitionReaderFactory { - private static final int BATCH_SIZE = 20; - - @Override - public boolean supportColumnarReads(InputPartition partition) { - return true; - } - - @Override - public PartitionReader createReader(InputPartition partition) { - throw new UnsupportedOperationException(""); - } - - @Override - public PartitionReader createColumnarReader(InputPartition partition) { - JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - ColumnVector[] vectors = new ColumnVector[2]; - vectors[0] = i; - vectors[1] = j; - ColumnarBatch batch = new ColumnarBatch(vectors); - - return new PartitionReader() { - private int current = p.start; - - @Override - public boolean next() throws IOException { - i.reset(); - j.reset(); - int count = 0; - while (current < p.end && count < BATCH_SIZE) { - i.putInt(count, current); - j.putInt(count, -current); - current += 1; - count += 1; - } - - if (count == 0) { - return false; - } else { - batch.setNumRows(count); - return true; - } - } - - @Override - public ColumnarBatch get() { - return batch; - } - - @Override - public void close() throws IOException { - batch.close(); - } - }; - } - } - - @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 18a11dde8219..2d21324f5ece 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -19,34 +19,38 @@ import java.io.IOException; import java.util.Arrays; +import java.util.List; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.*; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import 
org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; +import org.apache.spark.sql.types.StructType; -public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { + class Reader implements DataSourceReader, SupportsReportPartitioning { + private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - InputPartition[] partitions = new InputPartition[2]; - partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); - partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); - return partitions; + public StructType readSchema() { + return schema; } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new SpecificReaderFactory(); + public List> planInputPartitions() { + return java.util.Arrays.asList( + new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override - public Partitioning outputPartitioning(ScanConfig config) { + public Partitioning outputPartitioning() { return new MyPartitioning(); } } @@ -62,53 +66,50 @@ public int numPartitions() { public boolean satisfy(Distribution distribution) { if (distribution instanceof ClusteredDistribution) { String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("i"); + return Arrays.asList(clusteredCols).contains("a"); } return false; } } - static class SpecificInputPartition implements InputPartition { - int[] i; - int[] j; + static class SpecificInputPartition implements InputPartition, + InputPartitionReader { + + private int[] i; + private int[] j; + private int current = -1; SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; } - } - static class SpecificReaderFactory implements PartitionReaderFactory { + @Override + public boolean next() throws IOException { + current += 1; + return current < i.length; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {i[current], j[current]}); + } + + @Override + public void close() throws IOException { + + } @Override - public PartitionReader createReader(InputPartition partition) { - SpecificInputPartition p = (SpecificInputPartition) partition; - return new PartitionReader() { - private int current = -1; - - @Override - public boolean next() throws IOException { - current += 1; - return current < p.i.length; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {p.i[current], p.j[current]}); - } - - @Override - public void close() throws IOException { - - } - }; + public InputPartitionReader createPartitionReader() { + return this; } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public DataSourceReader createReader(DataSourceOptions options) { + 
return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index cc9ac04a0dad..6fd6a44d2c4d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,39 +17,43 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import java.util.List; + +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupport { - class ReadSupport extends JavaSimpleReadSupport { + class Reader implements DataSourceReader { private final StructType schema; - ReadSupport(StructType schema) { + Reader(StructType schema) { this.schema = schema; } @Override - public StructType fullSchema() { + public StructType readSchema() { return schema; } @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - return new InputPartition[0]; + public List> planInputPartitions() { + return java.util.Collections.emptyList(); } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + public DataSourceReader createReader(DataSourceOptions options) { throw new IllegalArgumentException("requires a user-supplied schema"); } @Override - public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { - return new ReadSupport(schema); + public DataSourceReader createReader(StructType schema, DataSourceOptions options) { + return new Reader(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2cdbba84ec4a..274dc3745bcf 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,26 +17,72 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.types.StructType; + +public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceReader { 
+ private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> planInputPartitions() { + return java.util.Arrays.asList( + new JavaSimpleInputPartition(0, 5), + new JavaSimpleInputPartition(5, 10)); + } + } + + static class JavaSimpleInputPartition implements InputPartition, + InputPartitionReader { -public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { + private int start; + private int end; - class ReadSupport extends JavaSimpleReadSupport { + JavaSimpleInputPartition(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public InputPartitionReader createPartitionReader() { + return new JavaSimpleInputPartition(start - 1, end); + } @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - InputPartition[] partitions = new InputPartition[2]; - partitions[0] = new JavaRangeInputPartition(0, 5); - partitions[1] = new JavaRangeInputPartition(5, 10); - return partitions; + public boolean next() { + start += 1; + return start < end; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {start, -start}); + } + + @Override + public void close() throws IOException { + } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java deleted file mode 100644 index 685f9b9747e8..000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.StructType; - -abstract class JavaSimpleReadSupport implements BatchReadSupport { - - @Override - public StructType fullSchema() { - return new StructType().add("i", "int").add("j", "int"); - } - - @Override - public ScanConfigBuilder newScanConfigBuilder() { - return new JavaNoopScanConfigBuilder(fullSchema()); - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new JavaSimpleReaderFactory(); - } -} - -class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { - - private StructType schema; - - JavaNoopScanConfigBuilder(StructType schema) { - this.schema = schema; - } - - @Override - public ScanConfig build() { - return this; - } - - @Override - public StructType readSchema() { - return schema; - } -} - -class JavaSimpleReaderFactory implements PartitionReaderFactory { - - @Override - public PartitionReader createReader(InputPartition partition) { - JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - return new PartitionReader() { - private int current = p.start - 1; - - @Override - public boolean next() throws IOException { - current += 1; - return current < p.end; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {current, -current}); - } - - @Override - public void close() throws IOException { - - } - }; - } -} - -class JavaRangeInputPartition implements InputPartition { - int start; - int end; - - JavaRangeInputPartition(int start, int end) { - this.start = start; - this.end = end; - } -} diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a36b0cfa6ff1..46b38bed1c0f 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider +org.apache.spark.sql.streaming.sources.FakeWrite org.apache.spark.sql.streaming.sources.FakeNoWrite -org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback +org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 61857365ac98..7bb2cf59f5ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -43,7 +43,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("streaming writer") { val sink = new MemorySinkV2 - val writeSupport = new MemoryStreamingWriteSupport( + val writeSupport = new MemoryStreamWriter( sink, OutputMode.Append(), new StructType().add("i", "int")) 
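// A rough sketch of the write path this test drives, assuming the DataWriterFactory and
// DataWriter signatures shown later in this patch; the concrete record type accepted by the
// memory sink's writers is an assumption here, not something this diff spells out:
//   val factory = writeSupport.createWriterFactory()
//   val writer = factory.createDataWriter(0, 0L, 0L)  // partitionId, taskId, epochId
//   rows.foreach(writer.write)                        // buffer records for this partition/epoch
//   val message = writer.commit()                     // per-task WriterCommitMessage
//   writeSupport.commit(0, Array(message))            // epoch-level commit, as exercised next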
writeSupport.commit(0, Array( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala index 5884380271f0..55acf2ba28d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} -class ConsoleWriteSupportSuite extends StreamTest { +class ConsoleWriterSuite extends StreamTest { import testImplicits._ test("microbatch - default") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index dd74af873c2e..5ca13b89735b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -41,7 +43,7 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -54,10 +56,10 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => - val readSupport = ds.createMicroBatchReadSupport( - temp.getCanonicalPath, DataSourceOptions.empty()) - assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) + case ds: MicroBatchReadSupport => + val reader = 
ds.createMicroBatchReader( + Optional.empty(), temp.getCanonicalPath, DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) case _ => throw new IllegalStateException("Could not find read support for rate") } @@ -67,7 +69,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => + case ds: MicroBatchReadSupport => assert(ds.isInstanceOf[RateStreamProvider]) case _ => throw new IllegalStateException("Could not find read support for rate") @@ -139,19 +141,30 @@ class RateSourceSuite extends StreamTest { ) } + test("microbatch - set offset") { + withTempDir { temp => + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + } + test("microbatch - infer offsets") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val reader = new RateStreamMicroBatchReader( new DataSourceOptions( Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), temp.getCanonicalPath) - readSupport.clock.asInstanceOf[ManualClock].advance(100000) - val startOffset = readSupport.initialOffset() - startOffset match { + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { case r: LongOffset => assert(r.offset === 0L) case _ => throw new IllegalStateException("unexpected offset type") } - readSupport.latestOffset() match { + reader.getEndOffset() match { case r: LongOffset => assert(r.offset >= 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -160,16 +173,15 @@ class RateSourceSuite extends StreamTest { test("microbatch - predetermined batch size") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val reader = new RateStreamMicroBatchReader( new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createReaderFactory(config) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.planInputPartitions() assert(tasks.size == 1) - val dataReader = readerFactory.createReader(tasks(0)) + val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) @@ -180,25 +192,24 @@ class RateSourceSuite extends StreamTest { test("microbatch - data read") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val reader = new RateStreamMicroBatchReader( new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = 
readSupport.createReaderFactory(config) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.planInputPartitions() assert(tasks.size == 11) - val readData = tasks - .map(readerFactory.createReader) + val readData = tasks.asScala + .map(_.createPartitionReader()) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[InternalRow]() while (reader.next()) buf.append(reader.get()) buf } - assert(readData.map(_.getLong(1)).sorted === 0.until(33).toArray) + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) } } @@ -309,44 +320,41 @@ class RateSourceSuite extends StreamTest { } test("user-specified schema given") { - val exception = intercept[UnsupportedOperationException] { + val exception = intercept[AnalysisException] { spark.readStream .format("rate") .schema(spark.range(1).schema) .load() } assert(exception.getMessage.contains( - "rate source does not support user-specified schema")) + "rate source does not support a user-specified schema")) } test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupportProvider => - val readSupport = ds.createContinuousReadSupport( - "", DataSourceOptions.empty()) - assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find read support for continuous rate") } } test("continuous data") { - val readSupport = new RateStreamContinuousReadSupport( + val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createContinuousReaderFactory(config) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[InternalRow]() - tasks.foreach { + tasks.asScala.foreach { case t: RateStreamContinuousInputPartition => - val startTimeMs = readSupport.initialOffset() + val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = readerFactory.createReader(t) - .asInstanceOf[RateStreamContinuousPartitionReader] + val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 409156e5ebc7..48e5cf75bf8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -21,6 +21,7 @@ import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp +import java.util.Optional import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -33,8 +34,8 @@ import org.apache.spark.sql.execution.datasources.DataSource 
import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -48,9 +49,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread.join() serverThread = null } + if (batchReader != null) { + batchReader.stop() + batchReader = null + } } private var serverThread: ServerThread = null + private var batchReader: MicroBatchReader = null case class AddSocketData(data: String*) extends AddData { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { @@ -59,7 +65,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source + case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source } if (sources.isEmpty) { throw new Exception( @@ -85,7 +91,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => + case ds: MicroBatchReadSupport => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => throw new IllegalStateException("Could not find socket source") @@ -175,16 +181,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map.empty[String, String].asJava)) + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -193,7 +199,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReadSupport("", a) + provider.createMicroBatchReader(Optional.empty(), "", a) } } @@ -203,12 +209,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", 
"port" -> "1234") - val exception = intercept[UnsupportedOperationException] { - provider.createMicroBatchReadSupport( - userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) + val exception = intercept[AnalysisException] { + provider.createMicroBatchReader( + Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) } assert(exception.getMessage.contains( - "socket source does not support user-specified schema")) + "socket source does not support a user-specified schema")) } test("input row metrics") { @@ -299,27 +305,25 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val reader = new TextSocketContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - - val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() - val tasks = readSupport.planInputPartitions(scanConfig) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val numRecords = 10 val data = scala.collection.mutable.ListBuffer[Int]() val offsets = scala.collection.mutable.ListBuffer[Int]() - val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) import org.scalatest.time.SpanSugar._ failAfter(5 seconds) { // inject rows, read and check the data and offsets for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.foreach { + tasks.asScala.foreach { case t: TextSocketContinuousInputPartition => - val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] for (i <- 0 until numRecords / 2) { r.next() offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) @@ -335,15 +339,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before data.clear() case _ => throw new IllegalStateException("Unexpected task type") } - assert(readSupport.startOffset.offsets == List(3, 3)) - readSupport.commit(TextSocketOffset(List(5, 5))) - assert(readSupport.startOffset.offsets == List(5, 5)) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(3, 3)) + reader.commit(TextSocketOffset(List(5, 5))) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(5, 5)) } def commitOffset(partition: Int, offset: Int): Unit = { - val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) - readSupport.commit(TextSocketOffset(offsetsToCommit)) - assert(readSupport.startOffset.offsets == offsetsToCommit) + val offsetsToCommit = reader.getStartOffset.asInstanceOf[TextSocketOffset] + .offsets.updated(partition, offset) + reader.commit(TextSocketOffset(offsetsToCommit)) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == offsetsToCommit) } } @@ -351,13 +356,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val reader = new TextSocketContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - - readSupport.startOffset = TextSocketOffset(List(5, 5)) + reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) + 
// ok to commit same offset + reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) assertThrows[IllegalStateException] { - readSupport.commit(TextSocketOffset(List(6, 6))) + reader.commit(TextSocketOffset(List(6, 6))) } } @@ -365,12 +371,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val reader = new TextSocketContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "includeTimestamp" -> "true", "port" -> serverThread.port.toString).asJava)) - val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() - val tasks = readSupport.planInputPartitions(scanConfig) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val numRecords = 4 @@ -378,10 +384,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) - tasks.foreach { + tasks.asScala.foreach { case t: TextSocketContinuousInputPartition => - val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] for (i <- 0 until numRecords / 2) { r.next() assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index f6c3e0ce82e3..12beca257a0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources.v2 +import java.util.{ArrayList, List => JList} + import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -36,21 +38,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ - private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] - }.head - } - - private def getJavaScanConfig( - query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] - }.head - } - test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -63,6 +50,18 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] + }.head + } + Seq(classOf[AdvancedDataSourceV2], 
classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -71,58 +70,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q1) - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + val reader = getReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } else { - val config = getJavaScanConfig(q1) - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + val reader = getJavaReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q2) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i", "j")) + val reader = getReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) } else { - val config = getJavaScanConfig(q2) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i", "j")) + val reader = getJavaReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q3) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i")) + val reader = getReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) } else { - val config = getJavaScanConfig(q3) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i")) + val reader = getJavaReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q4) + val reader = getReader(q4) // 'j < 10 is not supported by the testing data source. - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } else { - val config = getJavaScanConfig(q4) + val reader = getJavaReader(q4) // 'j < 10 is not supported by the testing data source. 
- assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } } } } test("columnar batch scan implementation") { - Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) @@ -154,25 +153,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('i).agg(sum('j)) + val groupByColA = df.groupBy('a).agg(sum('b)) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(groupByColA.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('i, 'j).agg(count("*")) + val groupByColAB = df.groupBy('a, 'b).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(groupByColAB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('j).agg(sum('i)) + val groupByColB = df.groupBy('b).agg(sum('a)) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(groupByColB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) + val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e @@ -273,30 +272,36 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23301: column pruning with arbitrary expressions") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val config1 = getScanConfig(q1) - assert(config1.requiredSchema.fieldNames === Seq("i")) + val reader1 = getReader(q1) + assert(reader1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val config2 = getScanConfig(q2) - assert(config2.requiredSchema.isEmpty) + val reader2 = getReader(q2) + assert(reader2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val config3 = getScanConfig(q3) - assert(config3.filters.isEmpty) - assert(config3.requiredSchema.fieldNames === Seq("j")) + val reader3 = getReader(q3) + assert(reader3.filters.isEmpty) + assert(reader3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. 
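// For orientation, a rough sketch (order approximate, names as defined later in this file)
// of how Spark drives the push-down mixins on this reader before planning a scan: filters
// are offered through pushFilters, required columns through pruneColumns, and only then is
// planInputPartitions() called, which is why these assertions can inspect the reader's
// `filters` and `requiredSchema` afterwards:
//   val reader = new AdvancedDataSourceV2().createReader(DataSourceOptions.empty())
//     .asInstanceOf[AdvancedDataSourceV2#Reader]
//   val postScanFilters = reader.pushFilters(Array(GreaterThan("i", 3)))  // unsupported filters come back
//   reader.pruneColumns(new StructType().add("i", "int"))
//   val partitions = reader.planInputPartitions()  // sees the pruned schema and pushed filters
// The q4 case below exercises the same pruning path through sort/limit/select.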
val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val config4 = getScanConfig(q4) - assert(config4.requiredSchema.fieldNames === Seq("i")) + val reader4 = getReader(q4) + assert(reader4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { @@ -319,291 +324,240 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } +class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { -case class RangeInputPartition(start: Int, end: Int) extends InputPartition - -case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { - override def build(): ScanConfig = this -} - -object SimpleReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val RangeInputPartition(start, end) = partition - new PartitionReader[InternalRow] { - private var current = start - 1 - - override def next(): Boolean = { - current += 1 - current < end - } - - override def get(): InternalRow = InternalRow(current, -current) + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def close(): Unit = {} + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -abstract class SimpleReadSupport extends BatchReadSupport { - override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. +class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - override def newScanConfigBuilder(): ScanConfigBuilder = { - NoopScanConfigBuilder(fullSchema()) - } + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - SimpleReaderFactory + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) + } } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } +class SimpleInputPartition(start: Int, end: Int) + extends InputPartition[InternalRow] + with InputPartitionReader[InternalRow] { + private var current = start - 1 -class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new SimpleInputPartition(start, end) - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - Array(RangeInputPartition(0, 5)) - } + override def next(): Boolean = { + current += 1 + current < end } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def get(): InternalRow = InternalRow(current, -current) + + override def close(): Unit = {} } -// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark -// tests still pass. 
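// A minimal sketch of the scan lifecycle these simple sources implement (the JavaConverters
// import for .asScala is assumed; everything else is defined in this file):
//   val reader = new SimpleSinglePartitionSource().createReader(DataSourceOptions.empty())
//   val schema = reader.readSchema()
//   reader.planInputPartitions().asScala.foreach { partition =>
//     val r = partition.createPartitionReader()
//     try { while (r.next()) { val row: InternalRow = r.get() } } finally { r.close() }
//   }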
-class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) - } - } + class Reader extends DataSourceReader + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } -} + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } -class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported + } - class ReadSupport extends SimpleReadSupport { - override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() + override def pushedFilters(): Array[Filter] = filters - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters + override def readSchema(): StructType = { + requiredSchema + } + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v } - val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + val res = new ArrayList[InputPartition[InternalRow]] if (lowerBound.isEmpty) { - res.append(RangeInputPartition(0, 5)) - res.append(RangeInputPartition(5, 10)) + res.add(new AdvancedInputPartition(0, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.append(RangeInputPartition(lowerBound.get + 1, 5)) - res.append(RangeInputPartition(5, 10)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.append(RangeInputPartition(lowerBound.get + 1, 10)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) } - res.toArray - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema - new AdvancedReaderFactory(requiredSchema) + res } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { +class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { - var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] + private var current = start - 1 - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + override def createPartitionReader(): InputPartitionReader[InternalRow] = { + new 
AdvancedInputPartition(start, end, requiredSchema) } - override def readSchema(): StructType = requiredSchema + override def close(): Unit = {} - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported + override def next(): Boolean = { + current += 1 + current < end } - override def pushedFilters(): Array[Filter] = filters - - override def build(): ScanConfig = this -} - -class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val RangeInputPartition(start, end) = partition - new PartitionReader[InternalRow] { - private var current = start - 1 - - override def next(): Boolean = { - current += 1 - current < end - } - - override def get(): InternalRow = { - val values = requiredSchema.map(_.name).map { - case "i" => current - case "j" => -current - } - InternalRow.fromSeq(values) - } - - override def close(): Unit = {} + override def get(): InternalRow = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current } + InternalRow.fromSeq(values) } } -class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { +class SchemaRequiredDataSource extends DataSourceV2 with ReadSupport { - class ReadSupport(val schema: StructType) extends SimpleReadSupport { - override def fullSchema(): StructType = schema - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = - Array.empty + class Reader(val readSchema: StructType) extends DataSourceReader { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = + java.util.Collections.emptyList() } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + override def createReader(options: DataSourceOptions): DataSourceReader = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def createBatchReadSupport( - schema: StructType, options: DataSourceOptions): BatchReadSupport = { - new ReadSupport(schema) + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = { + new Reader(schema) } } -class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) - } + class Reader extends DataSourceReader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - ColumnarReaderFactory + override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { + java.util.Arrays.asList( + new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -object ColumnarReaderFactory extends PartitionReaderFactory { - private final val BATCH_SIZE = 20 +class BatchInputPartitionReader(start: Int, end: Int) + extends InputPartition[ColumnarBatch] with 
InputPartitionReader[ColumnarBatch] { - override def supportColumnarReads(partition: InputPartition): Boolean = true + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - throw new UnsupportedOperationException - } + private var current = start - override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { - val RangeInputPartition(start, end) = partition - new PartitionReader[ColumnarBatch] { - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch(Array(i, j)) - - private var current = start - - override def next(): Boolean = { - i.reset() - j.reset() - - var count = 0 - while (current < end && count < BATCH_SIZE) { - i.putInt(count, current) - j.putInt(count, -current) - current += 1 - count += 1 - } + override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this - if (count == 0) { - false - } else { - batch.setNumRows(count) - true - } - } + override def next(): Boolean = { + i.reset() + j.reset() - override def get(): ColumnarBatch = batch + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } - override def close(): Unit = batch.close() + if (count == 0) { + false + } else { + batch.setNumRows(count) + true } } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = batch.close() } +class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { -class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { + class Reader extends DataSourceReader with SupportsReportPartitioning { + override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { // Note that we don't have same value of column `a` across partitions. 
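// A brief aside, sketched from the definitions in this file: the Reader also mixes in
// SupportsReportPartitioning and describes its layout via outputPartitioning(), which is how
// the groupBy('a) queries in the test above avoid a ShuffleExchangeExec:
//   val reader = new PartitionAwareDataSource().createReader(DataSourceOptions.empty())
//   val partitioning = reader.asInstanceOf[SupportsReportPartitioning].outputPartitioning()
//   assert(partitioning.numPartitions() == 2)
//   // satisfy(...) returning true for a ClusteredDistribution on column "a" elides the exchange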
- Array( - SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), - SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - SpecificReaderFactory + java.util.Arrays.asList( + new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) } - override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning + override def outputPartitioning(): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { override def numPartitions(): Int = 2 override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("i") + case c: ClusteredDistribution => c.clusteredColumns.contains("a") case _ => false } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition +class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) + extends InputPartition[InternalRow] + with InputPartitionReader[InternalRow] { + assert(i.length == j.length) -object SpecificReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val p = partition.asInstanceOf[SpecificInputPartition] - new PartitionReader[InternalRow] { - private var current = -1 + private var current = -1 - override def next(): Boolean = { - current += 1 - current < p.i.length - } + override def createPartitionReader(): InputPartitionReader[InternalRow] = this - override def get(): InternalRow = InternalRow(p.i(current), p.j(current)) - - override def close(): Unit = {} - } + override def next(): Boolean = { + current += 1 + current < i.length } + + override def get(): InternalRow = InternalRow(i(current), j(current)) + + override def close(): Unit = {} } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 952241b0b6be..e1b8e9c44d72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,36 +18,34 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util.Optional +import java.util.{Collections, List => JList, Optional} import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration /** * A HDFS based transactional writable data source. - * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`. 
- * Each job moves files from `target/_temporary/queryId/` to `target`. + * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/jobId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 - with BatchReadSupportProvider with BatchWriteSupportProvider { +class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { private val schema = new StructType().add("i", "long").add("j", "long") - class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { + class Reader(path: String, conf: Configuration) extends DataSourceReader { + override def readSchema(): StructType = schema - override def fullSchema(): StructType = schema - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -55,23 +53,21 @@ class SimpleWritableDataSource extends DataSourceV2 val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.map { f => - CSVInputPartitionReader(f.getPath.toUri.toString) - }.toArray + val serializableConf = new SerializableConfiguration(conf) + new SimpleCSVInputPartitionReader( + f.getPath.toUri.toString, + serializableConf): InputPartition[InternalRow] + }.toList.asJava } else { - Array.empty + Collections.emptyList() } } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - val serializableConf = new SerializableConfiguration(conf) - new CSVReaderFactory(serializableConf) - } } - class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { - override def createBatchWriterFactory(): DataWriterFactory = { + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { SimpleCounter.resetCounter - new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -80,7 +76,7 @@ class SimpleWritableDataSource extends DataSourceV2 override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) - val jobPath = new Path(new Path(finalPath, "_temporary"), queryId) + val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) val fs = jobPath.getFileSystem(conf) try { for (file <- fs.listStatus(jobPath).map(_.getPath)) { @@ -95,23 +91,23 @@ class SimpleWritableDataSource extends DataSourceV2 } override def abort(messages: Array[WriterCommitMessage]): Unit = { - val jobPath = new Path(new Path(path, "_temporary"), queryId) + val jobPath = new Path(new Path(path, "_temporary"), jobId) val fs = jobPath.getFileSystem(conf) fs.delete(jobPath, true) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new ReadSupport(path.toUri.toString, conf) + new Reader(path.toUri.toString, conf) } - override def createBatchWriteSupport( - queryId: String, + override def createWriter( + jobId: String, schema: StructType, mode: SaveMode, - options: DataSourceOptions): 
Optional[BatchWriteSupport] = { + options: DataSourceOptions): Optional[DataSourceWriter] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -134,42 +130,39 @@ class SimpleWritableDataSource extends DataSourceV2 } val pathStr = path.toUri.toString - Optional.of(new WritSupport(queryId, pathStr, conf)) + Optional.of(new Writer(jobId, pathStr, conf)) } } -case class CSVInputPartitionReader(path: String) extends InputPartition +class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { -class CSVReaderFactory(conf: SerializableConfiguration) - extends PartitionReaderFactory { + @transient private var lines: Iterator[String] = _ + @transient private var currentLine: String = _ + @transient private var inputStream: FSDataInputStream = _ - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val path = partition.asInstanceOf[CSVInputPartitionReader].path + override def createPartitionReader(): InputPartitionReader[InternalRow] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) + inputStream = fs.open(filePath) + lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + this + } - new PartitionReader[InternalRow] { - private val inputStream = fs.open(filePath) - private val lines = new BufferedReader(new InputStreamReader(inputStream)) - .lines().iterator().asScala - - private var currentLine: String = _ - - override def next(): Boolean = { - if (lines.hasNext) { - currentLine = lines.next() - true - } else { - false - } - } + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } - override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) - override def close(): Unit = { - inputStream.close() - } - } + override def close(): Unit = { + inputStream.close() } } @@ -190,11 +183,12 @@ private[v2] object SimpleCounter { } class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory { + extends DataWriterFactory[InternalRow] { - override def createWriter( + override def createDataWriter( partitionId: Int, - taskId: Long): DataWriter[InternalRow] = { + taskId: Long, + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 491dc34afa14..35644c58cf79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -686,7 +686,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.readSupport + case r: StreamingDataSourceV2Relation => r.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index fe77a1b4469c..0f15cd6e5a50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def latestOffset(): OffsetV2 = { + override def getEndOffset: OffsetV2 = { numTriggers += 1 - super.latestOffset() + super.getEndOffset } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1dd817545a96..0278e2a36890 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.CountDownLatch import scala.collection.mutable @@ -30,12 +32,13 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -212,17 +215,25 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi private def dataAdded: Boolean = currentOffset.offset != -1 - // latestOffset should take 50 ms the first time it is called after data is added - override def latestOffset(): OffsetV2 = synchronized { - if (dataAdded) clock.waitTillTime(1050) - super.latestOffset() + // setOffsetRange should take 50 ms the first time it is called after data is added + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.setOffsetRange(start, end) + } + } + + // getEndOffset should take 100 ms the first time it is called after data is added + override def getEndOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1150) + super.getEndOffset() } // getBatch should take 100 ms the first time it is called - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { - clock.waitTillTime(1150) - super.planInputPartitions(config) + clock.waitTillTime(1350) + super.planInputPartitions() } } } @@ -263,26 +274,34 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next 
trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress when `latestOffset` is being called + // Test status and progress when setOffsetRange is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset` + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(50), // time = 1050 to unblock `latestOffset` + AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange AssertClockTime(1050), - // will block on `planInputPartitions` that needs 1350 - AssertStreamExecThreadIsWaitingForTime(1150), + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 + AssertOnQuery(_.status.isDataAvailable === false), + AssertOnQuery(_.status.isTriggerActive === true), + AssertOnQuery(_.status.message.startsWith("Getting offsets from")), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), + + AdvanceManualClock(100), // time = 1150 to unblock getEndOffset + AssertClockTime(1150), + // will block on planInputPartitions that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1350), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions` - AssertClockTime(1150), + AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions + AssertClockTime(1350), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), @@ -290,7 +309,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(350), // time = 1500 to unblock map task + AdvanceManualClock(150), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger @@ -310,10 +329,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("latestOffset") === 50) - assert(progress.durationMs.get("queryPlanning") === 100) + assert(progress.durationMs.get("setOffsetRange") === 50) + assert(progress.durationMs.get("getEndOffset") === 100) + assert(progress.durationMs.get("queryPlanning") === 200) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 350) + assert(progress.durationMs.get("addBatch") === 150) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index 
d6819eacd07c..4f198819b58d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -22,15 +22,16 @@ import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType} class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { case class LongPartitionOffset(offset: Long) extends PartitionOffset @@ -43,8 +44,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamingWriteSupport], - mock[ContinuousReadSupport], + mock[StreamWriter], + mock[ContinuousReader], mock[ContinuousExecution], coordinatorId, startEpoch, @@ -72,26 +73,26 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val partitionReader = new ContinuousPartitionReader[InternalRow] { - var index = -1 - var curr: UnsafeRow = _ - - override def next() = { - curr = queue.take() - index += 1 - true - } + val factory = new InputPartition[InternalRow] { + override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } - override def get = curr + override def get = curr - override def getOffset = LongPartitionOffset(index) + override def getOffset = LongPartitionOffset(index) - override def close() = {} + override def close() = {} + } } val reader = new ContinuousQueuedDataReader( - 0, - partitionReader, - new StructType().add("i", "int"), + new ContinuousDataSourceRDDPartition(0, factory), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 3d21bc63e0cc..4980b0cd41f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 3c973d8ebc70..82836dced9df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,20 +40,20 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writeSupport: StreamingWriteSupport = _ + private var writer: StreamWriter = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { - val reader = mock[ContinuousReadSupport] - writeSupport = mock[StreamingWriteSupport] + val reader = mock[ContinuousReader] + writer = mock[StreamWriter] query = mock[ContinuousExecution] - orderVerifier = inOrder(writeSupport, query) + orderVerifier = inOrder(writer, query) spark = new TestSparkSession() epochCoordinator - = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get) + = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) } test("single epoch") { @@ -209,12 +209,12 @@ class EpochCoordinatorSuite } private def verifyCommit(epoch: Long): Unit = { - orderVerifier.verify(writeSupport).commit(eqTo(epoch), any()) + orderVerifier.verify(writer).commit(eqTo(epoch), any()) orderVerifier.verify(query).commit(epoch) } private def verifyNoCommitFor(epoch: Long): Unit = { - verify(writeSupport, never()).commit(eqTo(epoch), any()) + verify(writer, never()).commit(eqTo(epoch), any()) verify(query, never()).commit(epoch) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index aeef4c8fe933..52b833a19c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,74 +17,73 @@ package org.apache.spark.sql.streaming.sources +import java.util.Optional + import 
org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} -import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport { - override def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} - override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - override def fullSchema(): StructType = StructType(Seq()) - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null - override def initialOffset(): Offset = RateStreamOffset(Map()) - override def latestOffset(): Offset = RateStreamOffset(Map()) - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - throw new IllegalStateException("fake source - cannot actually read") - } - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - throw new IllegalStateException("fake source - cannot actually read") - } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { +case class FakeReader() extends MicroBatchReader with ContinuousReader { + def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} + def getStartOffset: Offset = RateStreamOffset(Map()) + def getEndOffset: Offset = RateStreamOffset(Map()) + def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + def commit(end: Offset): Unit = {} + def readSchema(): StructType = StructType(Seq()) + def stop(): Unit = {} + def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + def setStartOffset(start: Optional[Offset]): Unit = {} + + def planInputPartitions(): java.util.ArrayList[InputPartition[InternalRow]] = { throw new IllegalStateException("fake source - cannot actually read") } } -trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { - override def createMicroBatchReadSupport( +trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { + override def createMicroBatchReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport() + options: DataSourceOptions): 
MicroBatchReader = FakeReader() } -trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { - override def createContinuousReadSupport( +trait FakeContinuousReadSupport extends ContinuousReadSupport { + override def createContinuousReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport() + options: DataSourceOptions): ContinuousReader = FakeReader() } -trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { - override def createStreamingWriteSupport( +trait FakeStreamWriteSupport extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { + options: DataSourceOptions): StreamWriter = { throw new IllegalStateException("fake sink - cannot actually write") } } -class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider { +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { override def shortName(): String = "fake-read-microbatch-only" } -class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider { +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { override def shortName(): String = "fake-read-continuous-only" } class FakeReadBothModes extends DataSourceRegister - with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider { + with FakeMicroBatchReadSupport with FakeContinuousReadSupport { override def shortName(): String = "fake-read-microbatch-continuous" } @@ -92,7 +91,7 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider { +class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { override def shortName(): String = "fake-write-microbatch-continuous" } @@ -107,8 +106,8 @@ class FakeSink extends Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with FakeStreamingWriteSupportProvider with StreamSinkProvider { +class FakeWriteV1Fallback extends DataSourceRegister + with FakeStreamWriteSupport with StreamSinkProvider { override def createSink( sqlContext: SQLContext, @@ -191,11 +190,11 @@ class StreamingDataSourceV2Suite extends StreamTest { val v2Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteSupportProviderV1Fallback]) + .isInstanceOf[FakeWriteV1Fallback]) // Ensure we create a V1 sink with the config. Note the config is a comma separated // list, including other fake entries. - val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName + val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { val v1Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) @@ -219,37 +218,35 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. 
- case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) + case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupportProvider, _: StreamingWriteSupportProvider, - _: ContinuousTrigger) => + case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all case (r, _, _) - if !r.isInstanceOf[MicroBatchReadSupportProvider] - && !r.isInstanceOf[ContinuousReadSupportProvider] => + if !r.isInstanceOf[MicroBatchReadSupport] + && !r.isInstanceOf[ContinuousReadSupport] => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => + case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) - if !r.isInstanceOf[ContinuousReadSupportProvider] => + case (r, _: StreamWriteSupport, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupport] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") // Invalid - trigger is microbatch but reader is not case (r, _, t) - if !r.isInstanceOf[MicroBatchReadSupportProvider] && - !t.isInstanceOf[ContinuousTrigger] => + if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, s"Data source $read does not support microbatch processing") }
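The final hunk above rewrites the source/sink/trigger compatibility matrix in StreamingDataSourceV2Suite in terms of the pre-refactor interfaces (MicroBatchReadSupport, ContinuousReadSupport, StreamWriteSupport). As a reading aid, the following is a minimal, self-contained Scala sketch of that decision logic, under stated assumptions: the traits, the object names, and the continuousTrigger flag are hypothetical stand-ins introduced purely for illustration, not Spark's real v2 interfaces, and the suite's post-creation negative case is flattened into an ordinary unsupported result.

object StreamingV2CompatibilitySketch {

  // Hypothetical stand-ins for MicroBatchReadSupport, ContinuousReadSupport and
  // StreamWriteSupport; they only mark capabilities and carry no behaviour.
  trait MicroBatchCapable
  trait ContinuousCapable
  trait StreamWriteCapable

  sealed trait Outcome
  case object Supported extends Outcome
  final case class Unsupported(reason: String) extends Outcome

  // Mirrors the order of the cases in the suite: unreadable source, unwritable sink,
  // continuous trigger without a continuous reader, microbatch trigger without a
  // microbatch reader.
  def validate(reader: AnyRef, writer: AnyRef, continuousTrigger: Boolean): Outcome = {
    val microBatch = reader.isInstanceOf[MicroBatchCapable]
    val continuous = reader.isInstanceOf[ContinuousCapable]
    val canWrite = writer.isInstanceOf[StreamWriteCapable]

    if (!microBatch && !continuous) {
      Unsupported("source does not support streamed reading")
    } else if (!canWrite) {
      Unsupported("sink does not support streamed writing")
    } else if (continuousTrigger && !continuous) {
      Unsupported("source does not support continuous processing")
    } else if (!continuousTrigger && !microBatch) {
      // In the suite this case is only detected after query creation
      // (testPostCreationNegativeCase); the sketch reduces it to a plain result.
      Unsupported("source does not support microbatch processing")
    } else {
      Supported
    }
  }

  def main(args: Array[String]): Unit = {
    object BothModes extends MicroBatchCapable with ContinuousCapable
    object ContinuousOnly extends ContinuousCapable
    object Sink extends StreamWriteCapable

    println(validate(BothModes, Sink, continuousTrigger = true))        // Supported
    println(validate(BothModes, Sink, continuousTrigger = false))       // Supported
    println(validate(ContinuousOnly, Sink, continuousTrigger = false))  // Unsupported(...)
    println(validate(BothModes, new AnyRef, continuousTrigger = true))  // Unsupported(...)
  }
}

This is the same shape of check the suite performs by instantiating the registered DataSourceRegister implementations and matching on isInstanceOf against the reader/writer interfaces together with the trigger type; only the interface names differ before and after the revert.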