
Commit 7e54a89 (parent 006e798)

Persist the current watermark to the commit log so that a restarted query
(for example, one run with Trigger.Once) resumes from the previous run's
watermark instead of resetting it to zero (SPARK-24699).

Co-authored-by: Tathagata Das <tathagata.das1565@gmail.com>
Co-authored-by: c-horn

4 files changed, 127 insertions(+), 19 deletions(-)

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala
19 additions, 14 deletions

@@ -22,6 +22,9 @@ import java.nio.charset.StandardCharsets._
 
 import scala.io.{Source => IOSource}
 
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
+
 import org.apache.spark.sql.SparkSession
 
 /**
@@ -43,36 +46,28 @@ import org.apache.spark.sql.SparkSession
  * line 2: metadata (optional json string)
  */
 class CommitLog(sparkSession: SparkSession, path: String)
-  extends HDFSMetadataLog[String](sparkSession, path) {
+  extends HDFSMetadataLog[CommitMetadata](sparkSession, path) {
 
   import CommitLog._
 
-  def add(batchId: Long): Unit = {
-    super.add(batchId, EMPTY_JSON)
-  }
-
-  override def add(batchId: Long, metadata: String): Boolean = {
-    throw new UnsupportedOperationException(
-      "CommitLog does not take any metadata, use 'add(batchId)' instead")
-  }
-
-  override protected def deserialize(in: InputStream): String = {
+  override protected def deserialize(in: InputStream): CommitMetadata = {
     // called inside a try-finally where the underlying stream is closed in the caller
     val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
     if (!lines.hasNext) {
       throw new IllegalStateException("Incomplete log file in the offset commit log")
     }
     parseVersion(lines.next.trim, VERSION)
-    EMPTY_JSON
+    val metadataJson = if (lines.hasNext) lines.next else EMPTY_JSON
+    CommitMetadata(metadataJson)
   }
 
-  override protected def serialize(metadata: String, out: OutputStream): Unit = {
+  override protected def serialize(metadata: CommitMetadata, out: OutputStream): Unit = {
     // called inside a try-finally where the underlying stream is closed in the caller
     out.write(s"v${VERSION}".getBytes(UTF_8))
     out.write('\n')
 
     // write metadata
-    out.write(EMPTY_JSON.getBytes(UTF_8))
+    out.write(metadata.json.getBytes(UTF_8))
   }
 }
 
@@ -81,3 +76,13 @@ object CommitLog {
   private val EMPTY_JSON = "{}"
 }
 
+
+case class CommitMetadata(nextBatchWatermarkMs: Long = 0) {
+  def json: String = Serialization.write(this)(CommitMetadata.format)
+}
+
+object CommitMetadata {
+  implicit val format = Serialization.formats(NoTypeHints)
+
+  def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json)
+}
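The backward compatibility of the log format hinges on json4s falling back to case-class default arguments when a JSON field is missing: an entry from an older log, whose metadata line is the legacy EMPTY_JSON ("{}"), decodes to a zero watermark. Below is a minimal, self-contained sketch of that round trip; the Demo* names are illustrative stand-ins, not part of the patch.

import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

// Illustrative stand-in mirroring the CommitMetadata case class added above.
case class DemoCommitMetadata(nextBatchWatermarkMs: Long = 0)

object DemoCommitMetadataRoundTrip {
  implicit val format = Serialization.formats(NoTypeHints)

  def main(args: Array[String]): Unit = {
    // Writing: the constructor field name becomes the JSON key.
    println(Serialization.write(DemoCommitMetadata(5000)))
    // prints {"nextBatchWatermarkMs":5000}

    // Reading a legacy "{}" metadata line: json4s falls back to the default
    // argument for the missing field, so old logs decode to watermark 0.
    println(Serialization.read[DemoCommitMetadata]("{}"))
    // prints DemoCommitMetadata(0)
  }
}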

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
6 additions, 4 deletions

@@ -265,7 +265,7 @@ class MicroBatchExecution(
      * latest batch id in the offset log, then we can safely move to the next batch
      * i.e., committedBatchId + 1 */
     commitLog.getLatest() match {
-      case Some((latestCommittedBatchId, _)) =>
+      case Some((latestCommittedBatchId, commitMetadata)) =>
         if (latestBatchId == latestCommittedBatchId) {
           /* The last batch was successfully committed, so we can safely process a
            * new next batch but first:
@@ -283,7 +283,9 @@ class MicroBatchExecution(
           currentBatchId = latestCommittedBatchId + 1
           isCurrentBatchConstructed = false
           committedOffsets ++= availableOffsets
-          // Construct a new batch be recomputing availableOffsets
+          watermarkTracker.setWatermark(
+            math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs))
+          logDebug(s"Recovered at $currentBatchId with watermark ${watermarkTracker.currentWatermark}")
         } else if (latestCommittedBatchId < latestBatchId - 1) {
           logWarning(s"Batch completion log latest batch id is " +
             s"${latestCommittedBatchId}, which is not trailing " +
@@ -533,11 +535,11 @@ class MicroBatchExecution(
     }
 
     withProgressLocked {
-      commitLog.add(currentBatchId)
+      watermarkTracker.updateWatermark(lastExecution.executedPlan)
+      commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
       committedOffsets ++= availableOffsets
       awaitProgressLockCondition.signalAll()
     }
-    watermarkTracker.updateWatermark(lastExecution.executedPlan)
     logDebug(s"Completed batch ${currentBatchId}")
   }
 
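Two details in this file are easy to miss. First, updateWatermark now runs before commitLog.add inside withProgressLocked, so the entry for batch N records the watermark the next batch should start from (hence the field name nextBatchWatermarkMs). Second, on restart the recovered value is the max of the tracker's current watermark and the persisted one, so recovery can only move the watermark forward. A small sketch of that monotonic-restore rule, with DemoWatermarkTracker as an illustrative stand-in for Spark's WatermarkTracker:

object DemoWatermarkRecovery {
  // Stand-in for WatermarkTracker: holds a single event-time watermark in ms.
  final class DemoWatermarkTracker {
    private var watermarkMs: Long = 0L
    def currentWatermark: Long = watermarkMs
    def setWatermark(ms: Long): Unit = { watermarkMs = ms }
  }

  // The recovery rule from the diff: a persisted value may advance the
  // watermark but never move it backwards.
  def restore(tracker: DemoWatermarkTracker, persistedMs: Long): Unit =
    tracker.setWatermark(math.max(tracker.currentWatermark, persistedMs))

  def main(args: Array[String]): Unit = {
    val tracker = new DemoWatermarkTracker
    restore(tracker, 5000L)           // commit log says 5000 -> advance to 5000
    restore(tracker, 3000L)           // stale smaller value is ignored
    println(tracker.currentWatermark) // prints 5000
  }
}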

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
1 addition, 1 deletion

@@ -311,7 +311,7 @@ class ContinuousExecution(
     assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit")
     synchronized {
       if (queryExecutionThread.isAlive) {
-        commitLog.add(epoch)
+        commitLog.add(epoch, CommitMetadata())
         val offset =
           continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json)
         committedOffsets ++= Seq(continuousSources(0) -> offset)
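Worth noting: continuous processing has no event-time watermark to carry across epochs, so this change appears purely mechanical, forced by the new add(batchId, metadata) signature. Each epoch commits the default CommitMetadata(), whose nextBatchWatermarkMs = 0 serializes to {"nextBatchWatermarkMs":0} under the json4s format defined in CommitLog.scala.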

sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
101 additions, 0 deletions

@@ -462,6 +462,91 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
     }
   }
 
+  test("SPARK-24699: watermark should behave the same for Trigger ProcessingTime / Once") {
+    val watermarkSeconds = 2
+    val windowSeconds = 5
+    val source = MemoryStream[Int]
+    val df = source
+      .toDF()
+      .withColumn("eventTime", 'value cast "timestamp")
+      .withWatermark("eventTime", s"$watermarkSeconds seconds")
+      .groupBy(window($"eventTime", s"$windowSeconds seconds") as 'window)
+      .count()
+      .select('window.getField("start").cast("long").as[Long], 'count.as[Long])
+
+    val (one, two, three, four) = (
+      Seq(1, 1, 2, 3, 4, 4, 6),
+      Seq(7, 8, 9),
+      Seq(11, 12, 13, 14, 14),
+      Seq(15)
+    )
+    val (resultsAfterOne, resultsAfterTwo, resultsAfterThree, resultsAfterFour) = (
+      CheckAnswer(),
+      CheckAnswer(),
+      CheckAnswer(0 -> 6),
+      CheckAnswer(0 -> 6, 5 -> 4)
+    )
+    val (statsAfterOne, statsAfterTwo, statsAfterThree, statsAfterFour) = (
+      assertEventStats(
+        min = one.min,
+        max = one.max,
+        avg = one.sum.toDouble / one.size,
+        watermark = 0,
+        name = "first"
+      ),
+      assertEventStats(
+        min = two.min,
+        max = two.max,
+        avg = two.sum.toDouble / two.size,
+        watermark = one.max - watermarkSeconds,
+        name = "second"
+      ),
+      assertEventStats(
+        min = three.min,
+        max = three.max,
+        avg = three.sum.toDouble / three.size,
+        watermark = two.max - watermarkSeconds,
+        name = "third"
+      ),
+      assertEventStats(
+        min = four.min,
+        max = four.max,
+        avg = four.sum.toDouble / four.size,
+        watermark = three.max - watermarkSeconds,
+        name = "fourth"
+      )
+    )
+
+    testStream(df)(
+      StartStream(Trigger.Once),
+      StopStream,
+
+      AddData(source, one: _*),
+      StartStream(Trigger.Once),
+      resultsAfterOne,
+      statsAfterOne,
+      StopStream,
+
+      AddData(source, two: _*),
+      StartStream(Trigger.Once),
+      resultsAfterTwo,
+      statsAfterTwo,
+      StopStream,
+
+      AddData(source, three: _*),
+      StartStream(Trigger.Once),
+      resultsAfterThree,
+      statsAfterThree,
+      StopStream,
+
+      AddData(source, four: _*),
+      StartStream(Trigger.Once),
+      resultsAfterFour,
+      statsAfterFour,
+      StopStream
+    )
+  }
+
   test("test no-data flag") {
     val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key
 
@@ -632,10 +717,26 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
     }
   }
 
+  private def assertEventStats(
+      min: Long,
+      max: Long,
+      avg: Double,
+      watermark: Long,
+      name: String = "event stats"): AssertOnQuery = assertEventStats { e =>
+    assert(e.get("min") === formatTimestamp(min), s"[$name]: min value")
+    assert(e.get("max") === formatTimestamp(max), s"[$name]: max value")
+    assert(e.get("avg") === formatTimestamp(avg), s"[$name]: avg value")
+    assert(e.get("watermark") === formatTimestamp(watermark), s"[$name]: watermark value")
+  }
+
   private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601
   timestampFormat.setTimeZone(ju.TimeZone.getTimeZone("UTC"))
 
   private def formatTimestamp(sec: Long): String = {
     timestampFormat.format(new ju.Date(sec * 1000))
   }
+
+  private def formatTimestamp(sec: Double): String = {
+    timestampFormat.format(new ju.Date((sec * 1000).toLong))
+  }
 }
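The CheckAnswer values in the new test can be verified by hand from the rules it encodes: the watermark is (max event time seen so far) minus watermarkSeconds, it takes effect in the following run, and a window [start, start + windowSeconds) is emitted once the effective watermark reaches its end. A quick plain-Scala check of the two non-empty answers (not part of the suite):

object DemoExpectedAnswers {
  def main(args: Array[String]): Unit = {
    val delay = 2                     // watermarkSeconds in the test
    val one = Seq(1, 1, 2, 3, 4, 4, 6)
    val two = Seq(7, 8, 9)

    // The watermark in effect during batch three is the one computed after
    // batch two: two.max - delay = 7 >= 5, which closes window [0, 5), i.e.
    // the six batch-one events below 5.
    println(0 -> one.count(_ < 5))    // prints (0,6)

    // During batch four the effective watermark is the batch-three max (14)
    // minus delay = 12 >= 10, closing window [5, 10): the event 6 from batch
    // one plus 7, 8 and 9 from batch two.
    println(5 -> (one ++ two).count(v => v >= 5 && v < 10)) // prints (5,4)
  }
}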
