diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index 2cc54107f8b8..d4f551179285 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -17,52 +17,148 @@ package org.apache.spark.sql.execution.streaming +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + + +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val byteStream = new ByteArrayOutputStream() + val objectStream = new ObjectOutputStream(byteStream) + objectStream.writeObject(writer) + ForeachWriterFactory(byteStream.toByteArray, encoder) + } + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + serializedWriter: Array[Byte], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter(partitionId: Int, attemptNumber: Int): ForeachDataWriter[T] = { + new ForeachDataWriter(serializedWriter, encoder, partitionId) + } +} /** - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by - * [[ForeachWriter]]. + * A [[DataWriter]] for the foreach sink. + * + * Note that [[ForeachWriter]] has the following lifecycle, and (as was true in the V1 sink API) + * assumes that it's never reused: + * * [create writer] + * * open(partitionId, batchId) + * * if open() returned true: write, write, write, ... + * * close() + * while DataSourceV2 writers have a slightly different lifecycle and will be reused for multiple + * epochs in the continuous processing engine: + * * [create writer] + * * write, write, write, ... + * * commit() * - * @param writer The [[ForeachWriter]] to process all data. - * @tparam T The expected type of the sink. + * The bulk of the implementation here is a shim between these two models. + * + * @param serializedWriter a serialized version of the user-provided [[ForeachWriter]] + * @param encoder encoder from [[Row]] to the type param [[T]] + * @param partitionId the ID of the partition this data writer is responsible for + * + * @tparam T the type of data to be handled by the writer */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) +class ForeachDataWriter[T : Encoder]( + serializedWriter: Array[Byte], + encoder: ExpressionEncoder[T], + partitionId: Int) + extends DataWriter[InternalRow] { + private val initialEpochId: Long = { + // Start with the microbatch ID. If it's not there, we're in continuous execution, + // so get the start epoch. + // This ID will be incremented as commits happen. + TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY) match { + case null => TaskContext.get().getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + case batch => batch.toLong + } + } + + // A small state machine representing the lifecycle of the underlying ForeachWriter. + // * CLOSED means close() has been called. + // * OPENED means open() was called and returned true. + // * OPENED_SKIP_PROCESSING means open() was called and returned false. + private object WriterState extends Enumeration { + type WriterState = Value + val CLOSED, OPENED, OPENED_SKIP_PROCESSING = Value + } + import WriterState._ + + private var writer: ForeachWriter[T] = _ + private var state: WriterState = _ + private var currentEpochId = initialEpochId + + private def openAndSetState(epochId: Long) = { + writer = new ObjectInputStream(new ByteArrayInputStream(serializedWriter)).readObject() + .asInstanceOf[ForeachWriter[T]] + + writer.open(partitionId, epochId) match { + case true => state = OPENED + case false => state = OPENED_SKIP_PROCESSING + } + } + + openAndSetState(initialEpochId) + + override def write(record: InternalRow): Unit = { + try { + state match { + case OPENED => writer.process(encoder.fromRow(record)) + case OPENED_SKIP_PROCESSING => () + case CLOSED => + // First record of a new epoch, so we need to open a new writer for it. + openAndSetState(currentEpochId) + writer.process(encoder.fromRow(record)) } + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + // Close if the writer got opened for this epoch. + state match { + case CLOSED => () + case _ => writer.close(null) } + state = CLOSED + currentEpochId += 1 + ForeachWriterCommitMessage } - override def toString(): String = "ForeachSink" + override def abort(): Unit = {} } + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84564b6639ac..2325e0ecc656 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -467,6 +467,9 @@ class MicroBatchExecution( case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } + sparkSession.sparkContext.setLocalProperty( + MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -507,4 +510,7 @@ class MicroBatchExecution( } } +object MicroBatchExecution { + val BATCH_ID_KEY = "sql.streaming.microbatch.batchId" +} object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2fc903168cfa..10286df75c15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index b249dd41a84a..3e79fd416628 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.io.Serializable import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable @@ -26,7 +27,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.functions.{count, window} -import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} import org.apache.spark.sql.test.SharedSQLContext class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { @@ -141,7 +142,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") + assert(e.getCause.getCause.getCause.getMessage === "error") assert(query.isActive === false) val allEvents = ForeachSinkSuite.allEvents() @@ -255,6 +256,89 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.stop() } } + + testQuietly("foreach does not reuse writers") { + withTempDir { checkpointDir => + val input = MemoryStream[Int] + val query = input.toDS().repartition(1).writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .foreach(new TestForeachWriter() { + override def process(value: Int): Unit = { + super.process(this.hashCode()) + } + }).start() + input.addData(0) + query.processAllAvailable() + input.addData(0) + query.processAllAvailable() + + val allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 2) + assert(allEvents(0)(1).isInstanceOf[ForeachSinkSuite.Process[Int]]) + val firstWriterId = allEvents(0)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value + assert(allEvents(1)(1).isInstanceOf[ForeachSinkSuite.Process[Int]]) + assert( + allEvents(1)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value != firstWriterId, + "writer was reused!") + } + } + + testQuietly("foreach sink for continuous query") { + withTempDir { checkpointDir => + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value.cast("INT")) + .map(r => r.getInt(0)) + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .trigger(Trigger.Continuous(500)) + .foreach(new TestForeachWriter with Serializable { + override def process(value: Int): Unit = { + super.process(this.hashCode()) + } + }).start() + try { + // Wait until we get 3 epochs with at least 3 events in them. This means we'll see + // open, close, and at least 1 process. + eventually(timeout(streamingTimeout)) { + // Check + assert(ForeachSinkSuite.allEvents().count(_.size >= 3) === 3) + } + + val allEvents = ForeachSinkSuite.allEvents().filter(_.size >= 3) + // Check open and close events. + allEvents(0).head match { + case ForeachSinkSuite.Open(0, _) => + case e => assert(false, s"unexpected event $e") + } + allEvents(1).head match { + case ForeachSinkSuite.Open(0, _) => + case e => assert(false, s"unexpected event $e") + } + allEvents(2).head match { + case ForeachSinkSuite.Open(0, _) => + case e => assert(false, s"unexpected event $e") + } + assert(allEvents(0).last == ForeachSinkSuite.Close(None)) + assert(allEvents(1).last == ForeachSinkSuite.Close(None)) + assert(allEvents(2).last == ForeachSinkSuite.Close(None)) + + // Check the first Process event in each epoch, and also check the writer IDs + // we packed in to make sure none got reused. + val writerIds = (0 to 2).map { i => + allEvents(i)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value + } + assert( + writerIds.toSet.size == 3, + s"writer was reused! expected 3 unique writers but saw $writerIds") + } finally { + query.stop() + } + } + } } /** A global object to collect events in the executor */