diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 5f0b195fcfcb..491859e4bd85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2} + /** * A simple offset for sources that produce a single linear stream of data. */ -case class LongOffset(offset: Long) extends Offset { +case class LongOffset(offset: Long) extends OffsetV2 { override val json = offset.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 70407f0580f9..036e52e6bf9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -273,7 +273,7 @@ class MicroBatchExecution( toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), Optional.empty()) - (s, Some(s.getEndOffset)) + (s, Option(s.getEndOffset)) } }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -396,10 +396,14 @@ class MicroBatchExecution( case (reader: MicroBatchReader, available) if committedOffsets.get(reader).map(_ != available).getOrElse(true) => val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + val availableV2: OffsetV2 = available match { + case v1: SerializedOffset => reader.deserializeOffset(v1.json) + case v2: OffsetV2 => v2 + } reader.setOffsetRange( toJava(current), - Optional.of(available.asInstanceOf[OffsetV2])) - logDebug(s"Retrieving data from $reader: $current -> $available") + Optional.of(availableV2)) + logDebug(s"Retrieving data from $reader: $current -> $availableV2") Some(reader -> new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 509a69dd922f..98c84969230b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -31,7 +32,8 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.reader.{DataReader, ReadTask} +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -51,9 +53,10 @@ object MemoryStream { * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends Source with Logging { + extends MicroBatchReader with Logging { protected val encoder = encoderFor[A] - protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession) + private val attributes = encoder.schema.toAttributes + protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -66,6 +69,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + private var startOffset = new LongOffset(-1) + private var endOffset = new LongOffset(-1) + /** * Last offset that was discarded, or -1 if no commits have occurred. Note that the value * -1 is used in calculations below and isn't just an arbitrary constant. @@ -73,8 +79,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def schema: StructType = encoder.schema - def toDS(): Dataset[A] = { Dataset(sqlContext.sparkSession, logicalPlan) } @@ -89,7 +93,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = { val encoded = data.toVector.map(d => encoder.toRow(d).copy()) - val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) + val plan = new LocalRelation(attributes, encoded, isStreaming = false) val ds = Dataset[A](sqlContext.sparkSession, plan) logDebug(s"Adding ds: $ds") this.synchronized { @@ -101,19 +105,25 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None - } else { - Some(currentOffset) + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + if (start.isPresent) { + startOffset = start.get().asInstanceOf[LongOffset] } + endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] } - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + override def readSchema(): StructType = encoder.schema + + override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) + + override def getStartOffset: OffsetV2 = if (startOffset.offset == -1) null else startOffset + + override def getEndOffset: OffsetV2 = if (endOffset.offset == -1) null else endOffset + + override def createReadTasks(): ju.List[ReadTask[Row]] = { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 + val startOrdinal = startOffset.offset.toInt + 1 + val endOrdinal = endOffset.offset.toInt + 1 // Internal buffer only holds the batches after lastCommittedOffset. val newBlocks = synchronized { @@ -123,19 +133,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) batches.slice(sliceStart, sliceEnd) } - if (newBlocks.isEmpty) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) - newBlocks - .map(_.toDF()) - .reduceOption(_ union _) - .getOrElse { - sys.error("No data selected!") - } + newBlocks.map { ds => + val items = ds.toDF().collect() + new MemoryStreamReadTask(items).asInstanceOf[ReadTask[Row]] + }.asJava } private def generateDebugString( @@ -153,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def commit(end: Offset): Unit = synchronized { + override def commit(end: OffsetV2): Unit = synchronized { def check(newOffset: LongOffset): Unit = { val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt @@ -181,6 +184,24 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } +class MemoryStreamReadTask(records: Array[Row]) extends ReadTask[Row] { + override def createDataReader(): DataReader[Row] = new MemoryStreamDataReader(records) +} + +class MemoryStreamDataReader(records: Array[Row]) extends DataReader[Row] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the array. + currentIndex += 1 + currentIndex < records.length + } + + override def get(): Row = records(currentIndex) + + override def close(): Unit = {} +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index c0ed12cec25e..509f69430c5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -152,7 +152,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] { } class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - var currentIndex = -1 + private var currentIndex = -1 override def next(): Boolean = { // Return true as long as the new index is in the seq. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c65e5d3dd75c..d1a04833390f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6..1423ab1a4996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -116,7 +116,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { (source, source.addData(data)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 9ff02dee288f..a0604b5fdcdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2} import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol @@ -273,9 +274,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getOffset: Option[Offset] = { + override def getEndOffset: OffsetV2 = { numTriggers += 1 - super.getOffset + super.getEndOffset } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 76201c63a270..519ee64a50f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.streaming +import java.{util => ju} import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils @@ -29,10 +30,12 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType import org.apache.spark.util.ManualClock @@ -207,18 +210,18 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { // getOffset should take 50 ms the first time it is called - override def getOffset: Option[Offset] = { - val offset = super.getOffset - if (offset.nonEmpty) { + override def getEndOffset: OffsetV2 = { + val offset = super.getEndOffset + if (offset != null) { clock.waitTillTime(1050) } offset } // getBatch should take 100 ms the first time it is called - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - if (start.isEmpty) clock.waitTillTime(1150) - super.getBatch(start, end) + override def createReadTasks(): ju.List[ReadTask[Row]] = { + if (getStartOffset.asInstanceOf[LongOffset].offset == -1L) clock.waitTillTime(1150) + super.createReadTasks() } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index dc92ad3b0c1a..2d0c64cd68d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} @@ -101,6 +102,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override def strategies: Seq[Strategy] = { experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ Seq( + DataSourceV2Strategy, FileSourceStrategy, DataSourceStrategy(conf), SpecialLimits,