diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala index 5b114242558d..0063318db332 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala @@ -22,6 +22,9 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + import org.apache.spark.sql.SparkSession /** @@ -43,36 +46,28 @@ import org.apache.spark.sql.SparkSession * line 2: metadata (optional json string) */ class CommitLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[String](sparkSession, path) { + extends HDFSMetadataLog[CommitMetadata](sparkSession, path) { import CommitLog._ - def add(batchId: Long): Unit = { - super.add(batchId, EMPTY_JSON) - } - - override def add(batchId: Long, metadata: String): Boolean = { - throw new UnsupportedOperationException( - "CommitLog does not take any metadata, use 'add(batchId)' instead") - } - - override protected def deserialize(in: InputStream): String = { + override protected def deserialize(in: InputStream): CommitMetadata = { // called inside a try-finally where the underlying stream is closed in the caller val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file in the offset commit log") } parseVersion(lines.next.trim, VERSION) - EMPTY_JSON + val metadataJson = if (lines.hasNext) lines.next else EMPTY_JSON + CommitMetadata(metadataJson) } - override protected def serialize(metadata: String, out: OutputStream): Unit = { + override protected def serialize(metadata: CommitMetadata, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller out.write(s"v${VERSION}".getBytes(UTF_8)) out.write('\n') // write metadata - out.write(EMPTY_JSON.getBytes(UTF_8)) + out.write(metadata.json.getBytes(UTF_8)) } } @@ -81,3 +76,13 @@ object CommitLog { private val EMPTY_JSON = "{}" } + +case class CommitMetadata(nextBatchWatermarkMs: Long = 0) { + def json: String = Serialization.write(this)(CommitMetadata.format) +} + +object CommitMetadata { + implicit val format = Serialization.formats(NoTypeHints) + + def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 16651dd060d7..e16f1da6493a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -265,7 +265,7 @@ class MicroBatchExecution( * latest batch id in the offset log, then we can safely move to the next batch * i.e., committedBatchId + 1 */ commitLog.getLatest() match { - case Some((latestCommittedBatchId, _)) => + case Some((latestCommittedBatchId, commitMetadata)) => if (latestBatchId == latestCommittedBatchId) { /* The last batch was successfully committed, so we can safely process a * new next batch but first: @@ -283,7 +283,9 @@ class MicroBatchExecution( currentBatchId = latestCommittedBatchId + 1 isCurrentBatchConstructed = false committedOffsets ++= availableOffsets - // Construct a new batch be recomputing availableOffsets + watermarkTracker.setWatermark( + math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs)) + println(s"Recovered at $currentBatchId with wm ${watermarkTracker.currentWatermark}") } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -533,11 +535,11 @@ class MicroBatchExecution( } withProgressLocked { - commitLog.add(currentBatchId) + watermarkTracker.updateWatermark(lastExecution.executedPlan) + commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) committedOffsets ++= availableOffsets awaitProgressLockCondition.signalAll() } - watermarkTracker.updateWatermark(lastExecution.executedPlan) logDebug(s"Completed batch ${currentBatchId}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a0bb8292d776..b07934ac3a5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -311,7 +311,7 @@ class ContinuousExecution( assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") synchronized { if (queryExecutionThread.isAlive) { - commitLog.add(epoch) + commitLog.add(epoch, CommitMetadata()) val offset = continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 58ed9790ea12..e6a22711aab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -462,6 +462,92 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("SPARK-24699: watermark should behave the same for Trigger ProcessingTime / Once") { + val watermarkSeconds = 2 + val windowSeconds = 5 + val source = MemoryStream[Int] + val df = { + source + .toDF() + .withColumn("eventTime", 'value cast "timestamp") + .withWatermark("eventTime", s"$watermarkSeconds seconds") + .groupBy(window($"eventTime", s"$windowSeconds seconds") as 'window) + .count() + .select('window.getField("start").cast("long").as[Long], 'count.as[Long]) + } + val (one, two, three, four) = ( + Seq(1, 1, 2, 3, 4, 4, 6), + Seq(7, 8, 9), + Seq(11, 12, 13, 14, 14), + Seq(15) + ) + val (resultsAfterOne, resultsAfterTwo, resultsAfterThree, resultsAfterFour) = ( + CheckAnswer(), + CheckAnswer(), + CheckAnswer(0 -> 6), + CheckAnswer(0 -> 6, 5 -> 4) + ) + val (statsAfterOne, statsAfterTwo, statsAfterThree, statsAfterFour) = ( + assertEventStats( + min = one.min, + max = one.max, + avg = one.sum.toDouble / one.size, + watermark = 0, + "first" + ), + assertEventStats( + min = two.min, + max = two.max, + avg = two.sum.toDouble / two.size, + watermark = one.max - watermarkSeconds, + "second" + ), + assertEventStats( + min = three.min, + max = three.max, + avg = three.sum.toDouble / three.size, + watermark = two.max - watermarkSeconds, + "third" + ), + assertEventStats( + min = four.min, + max = four.max, + avg = four.sum.toDouble / four.size, + watermark = three.max - watermarkSeconds, + "fourth" + ) + ) + + testStream(df)( + StartStream(Trigger.Once), + StopStream, + + AddData(source, one: _*), + StartStream(Trigger.Once), + resultsAfterOne, + statsAfterOne, + StopStream, + + AddData(source, two: _*), + StartStream(Trigger.Once), + resultsAfterTwo, + statsAfterTwo, + StopStream, + + AddData(source, three: _*), + StartStream(Trigger.Once), + resultsAfterThree, + statsAfterThree, + StopStream, + + AddData(source, four: _*), + StartStream(Trigger.Once), + resultsAfterFour, + statsAfterFour, + StopStream + ) + } + test("test no-data flag") { val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key @@ -632,10 +718,26 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + private def assertEventStats( + min: Long, + max: Long, + avg: Double, + watermark: Long, + name: String = "event stats"): AssertOnQuery = assertEventStats { e => + assert(e.get("min") === formatTimestamp(min), s"[$name]: min value") + assert(e.get("max") === formatTimestamp(max), s"[$name]: max value") + assert(e.get("avg") === formatTimestamp(avg), s"[$name]: avg value") + assert(e.get("watermark") === formatTimestamp(watermark), s"[$name]: watermark value") + } + private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 timestampFormat.setTimeZone(ju.TimeZone.getTimeZone("UTC")) private def formatTimestamp(sec: Long): String = { timestampFormat.format(new ju.Date(sec * 1000)) } + + private def formatTimestamp(sec: Double): String = { + timestampFormat.format(new ju.Date((sec * 1000).toLong)) + } }