From bf6d2b5c0fbd7c67934f057daed9e0acd0dfb66c Mon Sep 17 00:00:00 2001 From: ericm-db Date: Tue, 18 Nov 2025 15:45:28 -0800 Subject: [PATCH 1/5] Introducing OffsetMap to enable Named Streaming Sources --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 4 +- .../v2/state/StateDataSource.scala | 4 +- .../checkpointing/AsyncOffsetSeqLog.scala | 2 +- .../streaming/checkpointing/OffsetSeq.scala | 92 +++++++++++++++--- .../checkpointing/OffsetSeqLog.scala | 95 +++++++++++++++---- .../continuous/ContinuousExecution.scala | 9 +- .../AcceptsLatestSeenOffsetHandler.scala | 9 +- ...cProgressTrackingMicroBatchExecution.scala | 15 +-- .../runtime/MicroBatchExecution.scala | 39 +++++--- .../streaming/runtime/StreamExecution.scala | 12 ++- .../streaming/runtime/StreamProgress.scala | 40 +++++++- ...ressTrackingMicroBatchExecutionSuite.scala | 4 +- .../streaming/OffsetSeqLogSuite.scala | 36 +++++-- 13 files changed, 282 insertions(+), 79 deletions(-) diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e619adfce17b4..f4bd782617ae2 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2ScanRelation import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq +import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, StreamExecution, StreamingExecutionRelation} import org.apache.spark.sql.execution.streaming.runtime.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED} @@ -854,7 +854,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with true }, AssertOnQuery { q => - val latestOffset: Option[(Long, OffsetSeq)] = q.offsetLog.getLatest() + val latestOffset: Option[(Long, OffsetSeqBase)] = q.offsetLog.getLatest() latestOffset.exists { offset => !offset._2.offsets.exists(_.exists(_.json == "{}")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index fea4d345b8d04..886c0f1b7c44f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -125,12 +125,12 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog offsetLog.get(batchId) match { case Some(value) => - val metadata = value.metadata.getOrElse( + val metadata = value.metadataOpt.getOrElse( throw StateDataSourceErrors.offsetMetadataLogUnavailable(batchId, checkpointLocation) ) val clonedSqlConf = session.sessionState.conf.clone() 
- OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf) + OffsetSeqMetadata.setSessionConf(metadata.asInstanceOf[OffsetSeqMetadata], clonedSqlConf) StateStoreConf(clonedSqlConf) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala index 18d18e61da475..e6ba644ed4833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala @@ -81,7 +81,7 @@ class AsyncOffsetSeqLog( * the async write of the batch is completed. Future may also be completed exceptionally * to indicate some write error. */ - def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = { + def addAsync(batchId: Long, metadata: OffsetSeqBase): CompletableFuture[(Long, Boolean)] = { require(metadata != null, "'null' metadata cannot written to a metadata log") def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala index 888dc0cdb9120..f56f4ac4673fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala @@ -35,30 +35,67 @@ import org.apache.spark.sql.execution.streaming.runtime.{MultipleWatermarkPolicy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ +trait OffsetSeqBase { + def offsets: Seq[Option[OffsetV2]] -/** - * An ordered collection of offsets, used to track the progress of processing data from one or more - * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance - * vector clock that must progress linearly forward. - */ -case class OffsetSeq(offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) { + def metadataOpt: Option[OffsetSeqMetadataBase] + + override def toString: String = this match { + case offsetMap: OffsetMap => + offsetMap.offsetsMap.map { case (sourceId, offsetOpt) => + s"$sourceId: ${offsetOpt.map(_.json).getOrElse("-")}" + }.mkString("{", ", ", "}") + case _ => + offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]") + } /** - * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of - * sources. + * Unpacks an offset into [[StreamProgress]] by associating each offset with the + * ordered list of sources. * - * This method is typically used to associate a serialized offset with actual sources (which - * cannot be serialized). + * This method is typically used to associate a serialized offset with actual + * sources (which cannot be serialized). */ def toStreamProgress(sources: Seq[SparkDataStream]): StreamProgress = { + assert(!this.isInstanceOf[OffsetMap], "toStreamProgress must be called with map") assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + - s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + - s"Cannot continue.") + s"checkpoint offsets and now there are [${sources.size}] sources requested by " + + s"the query. 
Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } - override def toString: String = - offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]") + /** + * Converts OffsetMap to StreamProgress using source ID mapping. + * This method is specific to OffsetMap and requires a mapping from sourceId to SparkDataStream. + */ + def toStreamProgress( + sources: Seq[SparkDataStream], + sourceIdToSourceMap: Map[String, SparkDataStream]): StreamProgress = { + this match { + case offsetMap: OffsetMap => + val streamProgressEntries = for { + (sourceId, offsetOpt) <- offsetMap.offsetsMap + offset <- offsetOpt + source <- sourceIdToSourceMap.get(sourceId) + } yield source -> offset + new StreamProgress ++ streamProgressEntries + case _ => + // Fallback to original method for backward compatibility + toStreamProgress(sources) + } + } +} + +/** + * An ordered collection of offsets, used to track the progress of processing data from one or more + * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance + * vector clock that must progress linearly forward. + */ +case class OffsetSeq( + offsets: Seq[Option[OffsetV2]], + metadata: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase { + + override def metadataOpt: Option[OffsetSeqMetadataBase] = metadata } object OffsetSeq { @@ -78,6 +115,30 @@ object OffsetSeq { } } +trait OffsetSeqMetadataBase extends Serializable { + def batchWatermarkMs: Long + def batchTimestampMs: Long + def conf: Map[String, String] + def json: String + def version: Int +} + +/** + * A map-based collection of offsets, used to track the progress of processing data from one or more + * streaming sources. Each source is identified by a string key (initially sourceId.toString()). + * This replaces the sequence-based approach with a more flexible map-based approach to support + * named source identities. + */ +case class OffsetMap( + offsetsMap: Map[String, Option[OffsetV2]], + metadataOpt: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase { + + // OffsetMap does not support sequence-based access + override def offsets: Seq[Option[OffsetV2]] = { + throw new UnsupportedOperationException( + "OffsetMap does not support sequence-based offsets access. Use offsetsMap directly.") + } +} /** * Contains metadata associated with a [[OffsetSeq]]. This information is @@ -97,7 +158,8 @@ object OffsetSeq { case class OffsetSeqMetadata( batchWatermarkMs: Long = 0, batchTimestampMs: Long = 0, - conf: Map[String, String] = Map.empty) { + conf: Map[String, String] = Map.empty) extends OffsetSeqMetadataBase { + override def version: Int = 1 def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala index 816563b3f09fd..f2a8af7ab5c11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala @@ -43,22 +43,26 @@ import org.apache.spark.sql.execution.streaming.runtime.SerializedOffset * - // No offset for this source i.e., an invalid JSON string * {2} // LongOffset 2 * ... + * + * Version 2 format (OffsetMap): + * v2 // version 2 + * metadata + * 0:{0} // sourceId:offset + * 1:{3} // sourceId:offset + * ... 
*/ class OffsetSeqLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[OffsetSeq](sparkSession, path) { + extends HDFSMetadataLog[OffsetSeqBase](sparkSession, path) { - override protected def deserialize(in: InputStream): OffsetSeq = { + override protected def deserialize(in: InputStream): OffsetSeqBase = { // called inside a try-finally where the underlying stream is closed in the caller - def parseOffset(value: String): OffsetV2 = value match { - case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null - case json => SerializedOffset(json) - } val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - validateVersion(lines.next(), OffsetSeqLog.VERSION) + val versionStr = lines.next() + val versionInt = validateVersion(versionStr, OffsetSeqLog.MAX_VERSION) // read metadata val metadata = lines.next().trim match { @@ -66,33 +70,82 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) case md => Some(md) } import org.apache.spark.util.ArrayImplicits._ - OffsetSeq.fill(metadata, lines.map(parseOffset).toArray.toImmutableArraySeq: _*) + if (versionInt == OffsetSeqLog.VERSION_2) { + // deserialize the remaining lines into the offset map + val remainingLines = lines.toArray + // New OffsetMap format: sourceId:offset + val offsetsMap = remainingLines.map { line => + val colonIndex = line.indexOf(':') + if (colonIndex == -1) { + throw new IllegalStateException(s"Invalid OffsetMap format: $line") + } + val sourceId = line.substring(0, colonIndex) + val offsetStr = line.substring(colonIndex + 1) + val offset = if (offsetStr == OffsetSeqLog.SERIALIZED_VOID_OFFSET) { + None + } else { + Some(OffsetSeqLog.parseOffset(offsetStr)) + } + sourceId -> offset + }.toMap + OffsetMap(offsetsMap, metadata.map(OffsetSeqMetadata.apply)) + } else { + OffsetSeq.fill(metadata, + lines.map(OffsetSeqLog.parseOffset).toArray.toImmutableArraySeq: _*) + } } - override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { + override protected def serialize(offsetSeq: OffsetSeqBase, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8)) + out.write(("v" + offsetSeq.metadataOpt.map(_.version).getOrElse(OffsetSeqLog.VERSION_1)) + .getBytes(UTF_8)) // write metadata out.write('\n') - out.write(offsetSeq.metadata.map(_.json).getOrElse("").getBytes(UTF_8)) + out.write(offsetSeq.metadataOpt.map(_.json).getOrElse("").getBytes(UTF_8)) - // write offsets, one per line - offsetSeq.offsets.map(_.map(_.json)).foreach { offset => - out.write('\n') - offset match { - case Some(json: String) => out.write(json.getBytes(UTF_8)) - case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8)) - } + offsetSeq match { + case offsetMap: OffsetMap => + // For OffsetMap, write sourceId:offset pairs, one per line + offsetMap.offsetsMap.foreach { case (sourceId, offsetOpt) => + out.write('\n') + out.write(sourceId.getBytes(UTF_8)) + out.write(':') + offsetOpt match { + case Some(offset) => out.write(offset.json.getBytes(UTF_8)) + case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8)) + } + } + case _ => + // Original sequence-based serialization + offsetSeq.offsets.map(_.map(_.json)).foreach { offset => + out.write('\n') + offset match { + case Some(json: String) => out.write(json.getBytes(UTF_8)) + case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8)) 
+ } + } } } def offsetSeqMetadataForBatchId(batchId: Long): Option[OffsetSeqMetadata] = { - if (batchId < 0) None else get(batchId).flatMap(_.metadata) + if (batchId < 0) { + None + } else { + get(batchId).flatMap(_.metadataOpt.map(_.asInstanceOf[OffsetSeqMetadata])) + } } } object OffsetSeqLog { - private[streaming] val VERSION = 1 - private val SERIALIZED_VOID_OFFSET = "-" + private[streaming] val VERSION_1 = 1 + private[streaming] val VERSION_2 = 2 + private[streaming] val VERSION = VERSION_1 // Default version for backward compatibility + private[streaming] val MAX_VERSION = VERSION_2 + private[streaming] val SERIALIZED_VOID_OFFSET = "-" + + private[checkpointing] def parseOffset(value: String): OffsetV2 = value match { + case SERIALIZED_VOID_OFFSET => null + case json => SerializedOffset(json) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 51cd457fbc856..ad7e9f3e4aa96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution -import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit} +import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.SQLExecution @@ -65,6 +65,8 @@ class ContinuousExecution( @volatile protected var sources: Seq[ContinuousStream] = Seq() + def sourceToIdMap: Map[SparkDataStream, String] = Map.empty + // For use only in test harnesses. 
private[sql] var currentEpochCoordinatorId: String = _ @@ -186,7 +188,7 @@ class ContinuousExecution( val nextOffsets = offsetLog.get(latestEpochId).getOrElse { throw new IllegalStateException( s"Batch $latestEpochId was committed without end epoch offsets!") - } + }.asInstanceOf[OffsetSeq] committedOffsets = nextOffsets.toStreamProgress(sources) execCtx.batchId = latestEpochId + 1 @@ -210,7 +212,8 @@ class ContinuousExecution( val execCtx = latestExecutionContext if (execCtx.batchId > 0) { - AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(Some(offsets), sources) + AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources( + Some(offsets), sources, Map.empty[String, SparkDataStream]) } val withNewSources: LogicalPlan = logicalPlan transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala index b15b93b47ada4..3eb5e6eb7d70d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala @@ -20,18 +20,19 @@ package org.apache.spark.sql.execution.streaming.runtime import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, SparkDataStream} import org.apache.spark.sql.execution.streaming.Source -import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq +import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase /** * This feeds "latest seen offset" to the sources that implement AcceptsLatestSeenOffset. 
*/ object AcceptsLatestSeenOffsetHandler { def setLatestSeenOffsetOnSources( - offsets: Option[OffsetSeq], - sources: Seq[SparkDataStream]): Unit = { + offsets: Option[OffsetSeqBase], + sources: Seq[SparkDataStream], + sourceIdMap: Map[String, SparkDataStream]): Unit = { assertNoAcceptsLatestSeenOffsetWithDataSourceV1(sources) - offsets.map(_.toStreamProgress(sources)) match { + offsets.map(_.toStreamProgress(sources, sourceIdMap)) match { case Some(streamProgress) => streamProgress.foreach { case (src: AcceptsLatestSeenOffset, offset) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala index 2a87ba3380883..4168df2e1f516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.streaming.WriteToStream import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, OneTimeTrigger, ProcessingTimeTrigger} -import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeq} +import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeqBase} import org.apache.spark.sql.execution.streaming.operators.stateful.StateStoreWriter import org.apache.spark.sql.streaming.Trigger import org.apache.spark.util.{Clock, ThreadUtils} @@ -49,7 +49,7 @@ class AsyncProgressTrackingMicroBatchExecution( // Offsets that are ready to be committed by the source. // This is needed so that we can call source commit in the same thread as micro-batch execution // to be thread safe - private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]() + private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeqBase]() // to cache the batch id of the last batch written to storage private val lastBatchPersistedToDurableStorage = new AtomicLong(-1) @@ -104,7 +104,7 @@ class AsyncProgressTrackingMicroBatchExecution( // perform quick validation to fail faster validateAndGetTrigger() - override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeq] = { + override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeqBase] = { /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. 
* The offset log may not be contiguous */ @@ -137,14 +137,15 @@ class AsyncProgressTrackingMicroBatchExecution( // Because we are using a thread pool with only one thread, async writes to the offset log // are still written in a serial / in order fashion offsetLog - .addAsync(execCtx.batchId, execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata)) - .thenAccept(tuple => { - val (batchId, persistedToDurableStorage) = tuple + .addAsync(execCtx.batchId, + execCtx.endOffsets.toOffsets(sources, sourceIdMap, execCtx.offsetSeqMetadata)) + .thenAccept((tuple: (Long, Boolean)) => { + val (batchId: Long, persistedToDurableStorage: Boolean) = tuple if (persistedToDurableStorage) { // batch id cache not initialized if (lastBatchPersistedToDurableStorage.get == -1) { lastBatchPersistedToDurableStorage.set( - offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1)) + offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1L)) } if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index cf2fca3d3cd8b..e08f9a2a52f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper} -import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeq, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqMetadata} import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter} import org.apache.spark.sql.execution.streaming.runtime.AcceptsLatestSeenOffsetHandler import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} @@ -100,6 +100,14 @@ class MicroBatchExecution( @volatile protected var sources: Seq[SparkDataStream] = Seq.empty + // Source ID mapping for OffsetMap support + // Using index as sourceId initially, can be extended to support user-provided names + // This is initialized in the same path as the sources Seq (defined above) and is used + // in the same way, when OffsetLog v2 is used. 
+ @volatile protected var sourceIdMap: Map[String, SparkDataStream] = Map.empty + + override protected def sourceToIdMap: Map[SparkDataStream, String] = sourceIdMap.map(_.swap) + @volatile protected[sql] var triggerExecutor: TriggerExecutor = _ protected def getTrigger(): TriggerExecutor = { @@ -243,6 +251,11 @@ class MicroBatchExecution( case r: StreamingDataSourceV2ScanRelation => r.stream } + // Create source ID mapping for OffsetMap support + sourceIdMap = sources.zipWithIndex.map { + case (source, index) => index.toString -> source + }.toMap + // Inform the source if it is in real time mode if (trigger.isInstanceOf[RealTimeTrigger]) { sources.foreach{ @@ -399,7 +412,10 @@ class MicroBatchExecution( } AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources( - offsetLog.getLatest().map(_._2), sources) + offsetLog.getLatest().map(_._2), + sources, + sourceIdMap + ) val execCtx = new MicroBatchExecutionContext(id, runId, name, triggerClock, sources, sink, progressReporter, -1, sparkSession, None) @@ -552,7 +568,7 @@ class MicroBatchExecution( * @param latestBatchId the batch id of the current micro batch * @return A option that contains the offset of the previously written batch */ - def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeq] = { + def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeqBase] = { if (latestBatchId != 0) { Some(offsetLog.get(latestBatchId - 1).getOrElse { logError(log"The offset log for batch ${MDC(LogKeys.BATCH_ID, latestBatchId - 1)} " + @@ -601,17 +617,18 @@ class MicroBatchExecution( * in the offset log */ execCtx.batchId = latestBatchId execCtx.isCurrentBatchConstructed = true - execCtx.endOffsets = nextOffsets.toStreamProgress(sources) + execCtx.endOffsets = nextOffsets.toStreamProgress(sources, sourceIdMap) // validate the integrity of offset log and get the previous offset from the offset log val secondLatestOffsets = validateOffsetLogAndGetPrevOffset(latestBatchId) secondLatestOffsets.foreach { offset => - execCtx.startOffsets = offset.toStreamProgress(sources) + execCtx.startOffsets = offset.toStreamProgress(sources, sourceIdMap) } // update offset metadata - nextOffsets.metadata.foreach { metadata => - OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf) + nextOffsets.metadataOpt.foreach { metadata => + OffsetSeqMetadata.setSessionConf( + metadata.asInstanceOf[OffsetSeqMetadata], sparkSessionToRunBatches.sessionState.conf) execCtx.offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan) @@ -846,8 +863,8 @@ class MicroBatchExecution( shouldConstructNextBatch } - protected def commitSources(offsetSeq: OffsetSeq): Unit = { - offsetSeq.toStreamProgress(sources).foreach { + protected def commitSources(offsetSeq: OffsetSeqBase): Unit = { + offsetSeq.toStreamProgress(sources, sourceIdMap).foreach { case (src: Source, off: Offset) => src.commit(off) case (stream: MicroBatchStream, off) => stream.commit(stream.deserializeOffset(off.json)) @@ -1106,7 +1123,7 @@ class MicroBatchExecution( if (!trigger.isInstanceOf[RealTimeTrigger]) { if (!offsetLog.add( execCtx.batchId, - execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata) + execCtx.endOffsets.toOffsets(sources, sourceIdMap, execCtx.offsetSeqMetadata) )) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } @@ -1262,7 +1279,7 @@ class 
MicroBatchExecution( execCtx.reportTimeTaken("walCommit") { if (!offsetLog.add( execCtx.batchId, - execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata) + execCtx.endOffsets.toOffsets(sources, sourceIdMap, execCtx.offsetSeqMetadata) )) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala index 56ed0de1fcdc6..65c1226a85dff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala @@ -154,6 +154,12 @@ abstract class StreamExecution( */ protected def sources: Seq[SparkDataStream] + /** + * Source-to-ID mapping for OffsetMap support. + * Using index as sourceId initially, can be extended to support user-provided names. + */ + protected def sourceToIdMap: Map[SparkDataStream, String] + /** Isolated spark session to run the batches with. */ protected[sql] val sparkSessionForStream: SparkSession = sparkSession.cloneSession() @@ -370,10 +376,12 @@ abstract class StreamExecution( toDebugString(includeLogicalPlan = isInitialized), cause = cause, getLatestExecutionContext().startOffsets - .toOffsetSeq(sources.toSeq, getLatestExecutionContext().offsetSeqMetadata) + .toOffsets(sources.toSeq, sourceToIdMap.map(_.swap), + getLatestExecutionContext().offsetSeqMetadata) .toString, getLatestExecutionContext().endOffsets - .toOffsetSeq(sources.toSeq, getLatestExecutionContext().offsetSeqMetadata) + .toOffsets(sources.toSeq, sourceToIdMap.map(_.swap), + getLatestExecutionContext().offsetSeqMetadata) .toString, errorClass = "STREAM_FAILED", messageParameters = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala index a6fd103e8d6a3..15f91fc6a2ef2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.streaming.runtime import scala.collection.immutable import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream} -import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetSeq, OffsetSeqMetadata} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeq, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataBase} /** * A helper class that looks like a Map[Source, Offset]. @@ -30,6 +31,43 @@ class StreamProgress( new immutable.HashMap[SparkDataStream, OffsetV2]) extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] { + /** + * Unified method to convert StreamProgress to appropriate OffsetSeq format. + * Handles both VERSION_1 (OffsetSeq) and VERSION_2 (OffsetMap) based on metadata version. 
+ */ + def toOffsets( + sources: Seq[SparkDataStream], + sourceIdMap: Map[String, SparkDataStream], + metadata: OffsetSeqMetadataBase): OffsetSeqBase = { + metadata.version match { + case OffsetSeqLog.VERSION_1 => + toOffsetSeq(sources, metadata) + case OffsetSeqLog.VERSION_2 => + toOffsetMap(sourceIdMap, metadata) + case v => + throw QueryExecutionErrors.logVersionGreaterThanSupported(v, OffsetSeqLog.MAX_VERSION) + } + } + + private def toOffsetSeq( + source: Seq[SparkDataStream], + metadata: OffsetSeqMetadataBase): OffsetSeqBase = { + OffsetSeq(source.map(get), Some(metadata.asInstanceOf[OffsetSeqMetadata])) + } + + private def toOffsetMap( + sourceIdMap: Map[String, SparkDataStream], + metadata: OffsetSeqMetadataBase): OffsetMap = { + // Compute reverse mapping only when needed + val sourceToIdMap = sourceIdMap.map(_.swap) + val offsetsMap = baseMap.map { case (source, offset) => + val sourceId = sourceToIdMap.getOrElse(source, + throw new IllegalArgumentException(s"Source $source not found in sourceToIdMap")) + sourceId -> Some(offset) + } + OffsetMap(offsetsMap, Some(metadata.asInstanceOf[OffsetSeqMetadata])) + } + def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala index e31e0e70cf39c..3f0e652646626 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.time.{Seconds, Span} import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.connector.read.streaming -import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, OffsetSeq} +import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog} import org.apache.spark.sql.execution.streaming.runtime.{AsyncProgressTrackingMicroBatchExecution, MemoryStream, StreamExecution} import org.apache.spark.sql.execution.streaming.runtime.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK} import org.apache.spark.sql.functions.{column, window} @@ -835,7 +835,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0) // commits received at source should match up to the ones found in the offset log for (i <- 0 until inputData.commits.length) { - val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get + val offsetOnDisk = offsetLog.get(offsetLogFiles(i)).get val sourceCommittedOffset: streaming.Offset = inputData.commits(i) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index e4312fd16d1fa..1f9160c9ac59e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -20,7 +20,7 @@ 
package org.apache.spark.sql.execution.streaming import java.io.File import org.apache.spark.sql.catalyst.util.stringToFile -import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetSeq, OffsetSeqLog, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeq, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata} import org.apache.spark.sql.execution.streaming.runtime.{LongOffset, SerializedOffset} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -109,7 +109,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { log.get(0) } Seq( - s"maximum supported log version is v${OffsetSeqLog.VERSION}, but encountered v99999", + s"maximum supported log version is v${OffsetSeqLog.MAX_VERSION}, but encountered v99999", "produced by a newer version of Spark and cannot be read by this version" ).foreach { message => assert(e.getMessage.contains(message)) @@ -124,10 +124,10 @@ class OffsetSeqLogSuite extends SharedSparkSession { Some(SerializedOffset("""{"logOffset":345}""")), Some(SerializedOffset("""{"topic-0":{"0":1}}""")) )) - assert(offsetSeq.metadata === Some(OffsetSeqMetadata(0L, 1480981499528L))) + assert(offsetSeq.metadataOpt === Some(OffsetSeqMetadata(0L, 1480981499528L))) } - private def readFromResource(dir: String): (Long, OffsetSeq) = { + private def readFromResource(dir: String): (Long, OffsetSeqBase) = { val input = getClass.getResource(s"/structured-streaming/$dir") val log = new OffsetSeqLog(spark, input.toString) log.getLatest().get @@ -161,7 +161,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { // Read the latest offset log val offsetSeq = log.get(latestBatchId.get).get - val offsetSeqMetadata = offsetSeq.metadata.get + val offsetSeqMetadata = offsetSeq.metadataOpt.get if (entryExists) { val encodingFormatOpt = offsetSeqMetadata.conf.get( @@ -171,7 +171,8 @@ class OffsetSeqLogSuite extends SharedSparkSession { } val clonedSqlConf = spark.sessionState.conf.clone() - OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf) + OffsetSeqMetadata.setSessionConf( + offsetSeqMetadata.asInstanceOf[OffsetSeqMetadata], clonedSqlConf) assert(clonedSqlConf.stateStoreEncodingFormat == encodingFormat) } @@ -210,13 +211,32 @@ class OffsetSeqLogSuite extends SharedSparkSession { withSQLConf(rowChecksumConf -> true.toString) { val existingChkpt = "offset-log-version-2.1.0" val (_, offsetSeq) = readFromResource(existingChkpt) - val offsetSeqMetadata = offsetSeq.metadata.get + val offsetSeqMetadata = offsetSeq.metadataOpt.get // Not present in existing checkpoint assert(offsetSeqMetadata.conf.get(rowChecksumConf) === None) val clonedSqlConf = spark.sessionState.conf.clone() - OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf) + OffsetSeqMetadata.setSessionConf( + offsetSeqMetadata.asInstanceOf[OffsetSeqMetadata], clonedSqlConf) assert(!clonedSqlConf.stateStoreRowChecksumEnabled) } } + + test("OffsetMap golden file compatibility test - VERSION_2 format") { + val (batchId, offsetSeq) = readFromResource("offset-map") + assert(batchId === 3) + + // Verify it's an OffsetMap (VERSION_2) + assert(offsetSeq.isInstanceOf[OffsetMap]) + val offsetMap = offsetSeq.asInstanceOf[OffsetMap] + + // Verify the offset data + assert(offsetMap.offsetsMap === Map("0" -> Some(SerializedOffset("3")))) + + // Verify metadata + assert(offsetSeq.metadataOpt.isDefined) + val metadata = offsetSeq.metadataOpt.get.asInstanceOf[OffsetSeqMetadata] + assert(metadata.batchWatermarkMs === 0) + 
assert(metadata.batchTimestampMs === 1758651405232L) + } } From 4ff4eb5dd160218c2c29f58b4f79e9c8932df698 Mon Sep 17 00:00:00 2001 From: ericm-db Date: Tue, 18 Nov 2025 17:26:27 -0800 Subject: [PATCH 2/5] compilation --- .../state/OfflineStateRepartitionRunner.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index 2456b2c9b73b6..6880cbeddc141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -129,7 +129,8 @@ class OfflineStateRepartitionRunner( // If it is a failed repartition batch, lets check if the shuffle partitions // is the same as the requested. If same, then we can retry the batch. val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get - val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get + val lastBatchShufflePartitions = getShufflePartitions( + lastBatch.metadataOpt.get.asInstanceOf[OffsetSeqMetadata]).get if (lastBatchShufflePartitions == numPartitions) { // We can retry the repartition batch. logInfo(log"The last batch is a failed repartition batch " + @@ -193,10 +194,10 @@ class OfflineStateRepartitionRunner( .offsetSeqNotFoundError(checkpointLocation, lastCommittedBatchId)) // Missing offset metadata not supported - val lastCommittedMetadata = lastCommittedOffsetSeq.metadata.getOrElse( + val lastCommittedMetadata = lastCommittedOffsetSeq.metadataOpt.getOrElse( throw OfflineStateRepartitionErrors.missingOffsetSeqMetadataError( checkpointLocation, version = 1, batchId = lastCommittedBatchId) - ) + ).asInstanceOf[OffsetSeqMetadata] // No-op if the number of shuffle partitions in last commit is the same as the requested. 
if (getShufflePartitions(lastCommittedMetadata).get == numPartitions) { @@ -253,13 +254,15 @@ object OfflineStateRepartitionUtils { throw OfflineStateRepartitionErrors .offsetSeqNotFoundError(checkpointLocation, prevBatchId)) - val batchMetadata = batch.metadata.getOrElse(throw OfflineStateRepartitionErrors + val batchMetadata = batch.metadataOpt.getOrElse(throw OfflineStateRepartitionErrors .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = batchId)) + .asInstanceOf[OffsetSeqMetadata] val shufflePartitions = getShufflePartitions(batchMetadata).get - val previousBatchMetadata = previousBatch.metadata.getOrElse( + val previousBatchMetadata = previousBatch.metadataOpt.getOrElse( throw OfflineStateRepartitionErrors .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = prevBatchId)) + .asInstanceOf[OffsetSeqMetadata] val previousShufflePartitions = getShufflePartitions(previousBatchMetadata).get previousShufflePartitions != shufflePartitions From d7eb5814260b51706875e095e7391be20cb04a84 Mon Sep 17 00:00:00 2001 From: ericm-db Date: Wed, 19 Nov 2025 11:42:12 -0800 Subject: [PATCH 3/5] compiles --- .../streaming/state/OfflineStateRepartitionSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala index 86b5502b652ed..f8b1f89b8d6d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata, OffsetSeqMetadata} import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -258,7 +258,8 @@ class OfflineStateRepartitionSuite extends StreamTest { assert(lastBatchId == batchId) val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get - val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get + val lastBatchShufflePartitions = getShufflePartitions( + lastBatch.metadataOpt.get.asInstanceOf[OffsetSeqMetadata]).get assert(lastBatchShufflePartitions == expectedShufflePartitions) // Verify the commit log @@ -277,7 +278,8 @@ class OfflineStateRepartitionSuite extends StreamTest { s"Offsets should be identical between batch $previousBatchId and $batchId") // Verify metadata is the same except for shuffle partitions config - (lastBatch.metadata, previousBatch.metadata) match { + (lastBatch.metadataOpt.map(_.asInstanceOf[OffsetSeqMetadata]), + previousBatch.metadataOpt.map(_.asInstanceOf[OffsetSeqMetadata])) match { case (Some(lastMetadata), Some(previousMetadata)) => // Check watermark and timestamp are the same assert(lastMetadata.batchWatermarkMs == previousMetadata.batchWatermarkMs, From c2a6a03c01e46e4374d11f058b89ba06fa023955 Mon Sep 17 00:00:00 2001 From: ericm-db Date: Wed, 19 Nov 2025 12:03:24 -0800 Subject: [PATCH 4/5] removed base trait --- .../datasources/v2/state/StateDataSource.scala | 2 +- .../streaming/checkpointing/OffsetSeq.scala | 15 
++++----------- .../streaming/checkpointing/OffsetSeqLog.scala | 2 +- .../runtime/MicroBatchExecution.scala | 3 +-- .../streaming/runtime/StreamProgress.scala | 18 +++++++----------- .../state/OfflineStateRepartitionRunner.scala | 6 ++---- .../streaming/OffsetSeqLogSuite.scala | 8 +++----- .../state/OfflineStateRepartitionSuite.scala | 8 +++----- 8 files changed, 22 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 459c4025c27d4..c97a70eb3c8ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -131,7 +131,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging ) val clonedSqlConf = session.sessionState.conf.clone() - OffsetSeqMetadata.setSessionConf(metadata.asInstanceOf[OffsetSeqMetadata], clonedSqlConf) + OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf) StateStoreConf(clonedSqlConf) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala index f56f4ac4673fd..a882d9539c4bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.internal.SQLConf._ trait OffsetSeqBase { def offsets: Seq[Option[OffsetV2]] - def metadataOpt: Option[OffsetSeqMetadataBase] + def metadataOpt: Option[OffsetSeqMetadata] override def toString: String = this match { case offsetMap: OffsetMap => @@ -95,7 +95,7 @@ case class OffsetSeq( offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase { - override def metadataOpt: Option[OffsetSeqMetadataBase] = metadata + override def metadataOpt: Option[OffsetSeqMetadata] = metadata } object OffsetSeq { @@ -115,13 +115,6 @@ object OffsetSeq { } } -trait OffsetSeqMetadataBase extends Serializable { - def batchWatermarkMs: Long - def batchTimestampMs: Long - def conf: Map[String, String] - def json: String - def version: Int -} /** * A map-based collection of offsets, used to track the progress of processing data from one or more @@ -158,8 +151,8 @@ case class OffsetMap( case class OffsetSeqMetadata( batchWatermarkMs: Long = 0, batchTimestampMs: Long = 0, - conf: Map[String, String] = Map.empty) extends OffsetSeqMetadataBase { - override def version: Int = 1 + conf: Map[String, String] = Map.empty, + version: Int = 1) { def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala index f2a8af7ab5c11..891a66b21b52e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala @@ -132,7 +132,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) if (batchId < 0) { None } else { - 
get(batchId).flatMap(_.metadataOpt.map(_.asInstanceOf[OffsetSeqMetadata])) + get(batchId).flatMap(_.metadataOpt) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index e08f9a2a52f5e..5ea97e6a2c32d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -627,8 +627,7 @@ class MicroBatchExecution( // update offset metadata nextOffsets.metadataOpt.foreach { metadata => - OffsetSeqMetadata.setSessionConf( - metadata.asInstanceOf[OffsetSeqMetadata], sparkSessionToRunBatches.sessionState.conf) + OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf) execCtx.offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala index 15f91fc6a2ef2..0708f931b77e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala @@ -21,7 +21,7 @@ import scala.collection.immutable import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeq, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataBase} +import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeq, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata} /** * A helper class that looks like a Map[Source, Offset]. 
@@ -38,7 +38,7 @@ class StreamProgress( def toOffsets( sources: Seq[SparkDataStream], sourceIdMap: Map[String, SparkDataStream], - metadata: OffsetSeqMetadataBase): OffsetSeqBase = { + metadata: OffsetSeqMetadata): OffsetSeqBase = { metadata.version match { case OffsetSeqLog.VERSION_1 => toOffsetSeq(sources, metadata) @@ -49,15 +49,15 @@ class StreamProgress( } } - private def toOffsetSeq( + def toOffsetSeq( source: Seq[SparkDataStream], - metadata: OffsetSeqMetadataBase): OffsetSeqBase = { - OffsetSeq(source.map(get), Some(metadata.asInstanceOf[OffsetSeqMetadata])) + metadata: OffsetSeqMetadata): OffsetSeq = { + OffsetSeq(source.map(get), Some(metadata)) } private def toOffsetMap( sourceIdMap: Map[String, SparkDataStream], - metadata: OffsetSeqMetadataBase): OffsetMap = { + metadata: OffsetSeqMetadata): OffsetMap = { // Compute reverse mapping only when needed val sourceToIdMap = sourceIdMap.map(_.swap) val offsetsMap = baseMap.map { case (source, offset) => @@ -65,11 +65,7 @@ class StreamProgress( throw new IllegalArgumentException(s"Source $source not found in sourceToIdMap")) sourceId -> Some(offset) } - OffsetMap(offsetsMap, Some(metadata.asInstanceOf[OffsetSeqMetadata])) - } - - def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { - OffsetSeq(source.map(get), Some(metadata)) + OffsetMap(offsetsMap, Some(metadata)) } override def toString: String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index 6880cbeddc141..63e3a6ec8d9a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -130,7 +130,7 @@ class OfflineStateRepartitionRunner( // is the same as the requested. If same, then we can retry the batch. val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get val lastBatchShufflePartitions = getShufflePartitions( - lastBatch.metadataOpt.get.asInstanceOf[OffsetSeqMetadata]).get + lastBatch.metadataOpt.get).get if (lastBatchShufflePartitions == numPartitions) { // We can retry the repartition batch. logInfo(log"The last batch is a failed repartition batch " + @@ -197,7 +197,7 @@ class OfflineStateRepartitionRunner( val lastCommittedMetadata = lastCommittedOffsetSeq.metadataOpt.getOrElse( throw OfflineStateRepartitionErrors.missingOffsetSeqMetadataError( checkpointLocation, version = 1, batchId = lastCommittedBatchId) - ).asInstanceOf[OffsetSeqMetadata] + ) // No-op if the number of shuffle partitions in last commit is the same as the requested. 
if (getShufflePartitions(lastCommittedMetadata).get == numPartitions) { @@ -256,13 +256,11 @@ object OfflineStateRepartitionUtils { val batchMetadata = batch.metadataOpt.getOrElse(throw OfflineStateRepartitionErrors .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = batchId)) - .asInstanceOf[OffsetSeqMetadata] val shufflePartitions = getShufflePartitions(batchMetadata).get val previousBatchMetadata = previousBatch.metadataOpt.getOrElse( throw OfflineStateRepartitionErrors .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = prevBatchId)) - .asInstanceOf[OffsetSeqMetadata] val previousShufflePartitions = getShufflePartitions(previousBatchMetadata).get previousShufflePartitions != shufflePartitions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 1f9160c9ac59e..2c3ae11a4e7ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -171,8 +171,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { } val clonedSqlConf = spark.sessionState.conf.clone() - OffsetSeqMetadata.setSessionConf( - offsetSeqMetadata.asInstanceOf[OffsetSeqMetadata], clonedSqlConf) + OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf) assert(clonedSqlConf.stateStoreEncodingFormat == encodingFormat) } @@ -216,8 +215,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { assert(offsetSeqMetadata.conf.get(rowChecksumConf) === None) val clonedSqlConf = spark.sessionState.conf.clone() - OffsetSeqMetadata.setSessionConf( - offsetSeqMetadata.asInstanceOf[OffsetSeqMetadata], clonedSqlConf) + OffsetSeqMetadata.setSessionConf(offsetSeqMetadata, clonedSqlConf) assert(!clonedSqlConf.stateStoreRowChecksumEnabled) } } @@ -235,7 +233,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { // Verify metadata assert(offsetSeq.metadataOpt.isDefined) - val metadata = offsetSeq.metadataOpt.get.asInstanceOf[OffsetSeqMetadata] + val metadata = offsetSeq.metadataOpt.get assert(metadata.batchWatermarkMs === 0) assert(metadata.batchTimestampMs === 1758651405232L) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala index f8b1f89b8d6d2..860e7a1ab2e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -258,8 +258,7 @@ class OfflineStateRepartitionSuite extends StreamTest { assert(lastBatchId == batchId) val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get - val lastBatchShufflePartitions = getShufflePartitions( - lastBatch.metadataOpt.get.asInstanceOf[OffsetSeqMetadata]).get 
+ val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadataOpt.get).get assert(lastBatchShufflePartitions == expectedShufflePartitions) // Verify the commit log @@ -278,8 +277,7 @@ class OfflineStateRepartitionSuite extends StreamTest { s"Offsets should be identical between batch $previousBatchId and $batchId") // Verify metadata is the same except for shuffle partitions config - (lastBatch.metadataOpt.map(_.asInstanceOf[OffsetSeqMetadata]), - previousBatch.metadataOpt.map(_.asInstanceOf[OffsetSeqMetadata])) match { + (lastBatch.metadataOpt, previousBatch.metadataOpt) match { case (Some(lastMetadata), Some(previousMetadata)) => // Check watermark and timestamp are the same assert(lastMetadata.batchWatermarkMs == previousMetadata.batchWatermarkMs, From b435542678d6c0400bd991ed621ce80bf0951a98 Mon Sep 17 00:00:00 2001 From: ericm-db Date: Wed, 19 Nov 2025 15:10:43 -0800 Subject: [PATCH 5/5] oops --- sql/core/src/test/resources/structured-streaming/offset-map/0 | 3 +++ sql/core/src/test/resources/structured-streaming/offset-map/1 | 3 +++ sql/core/src/test/resources/structured-streaming/offset-map/2 | 3 +++ sql/core/src/test/resources/structured-streaming/offset-map/3 | 3 +++ 4 files changed, 12 insertions(+) create mode 100644 sql/core/src/test/resources/structured-streaming/offset-map/0 create mode 100644 sql/core/src/test/resources/structured-streaming/offset-map/1 create mode 100644 sql/core/src/test/resources/structured-streaming/offset-map/2 create mode 100644 sql/core/src/test/resources/structured-streaming/offset-map/3 diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/0 b/sql/core/src/test/resources/structured-streaming/offset-map/0 new file mode 100644 index 0000000000000..ca9c25d8cf3f6 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/0 @@ -0,0 +1,3 @@ +v2 +{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:0 diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/1 b/sql/core/src/test/resources/structured-streaming/offset-map/1 new file mode 100644 index 0000000000000..9e01cb1e2eae6 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/1 @@ -0,0 +1,3 @@ +v2 +{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:1 diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/2 b/sql/core/src/test/resources/structured-streaming/offset-map/2 new file mode 100644 index 0000000000000..833abc798a6a2 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/2 @@ -0,0 +1,3 @@ +v2 +{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:2 diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/3 
b/sql/core/src/test/resources/structured-streaming/offset-map/3 new file mode 100644 index 0000000000000..f108ad977be0f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/3 @@ -0,0 +1,3 @@ +v2 +{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:3
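
Note (illustrative only, not part of the patch series): the v2 offset-log entries added in PATCH 5/5 follow the layout described in the OffsetSeqLog doc comment: a "v2" version line, one line of OffsetSeqMetadata JSON, then one "sourceId:offset" line per source, with "-" standing for a void offset. The standalone Scala sketch below parses that layout into a plain map, mirroring the colon-splitting logic in OffsetSeqLog.deserialize. The names V2Entry and parseV2Entry, and the string-based offset representation, are hypothetical stand-ins used only to illustrate the on-disk format; they are not Spark APIs.

// Illustrative sketch only: parses the v2 offset-log entry layout shown in the
// offset-map golden files ("v2", metadata JSON, then "sourceId:offset" lines).
// V2Entry and parseV2Entry are hypothetical names, not Spark classes.
object OffsetLogV2FormatSketch {
  // Simplified stand-in for Spark's offset classes: keep offsets as raw strings.
  final case class V2Entry(metadataJson: String, offsets: Map[String, Option[String]])

  def parseV2Entry(lines: Seq[String]): V2Entry = {
    require(lines.length >= 2 && lines.head == "v2",
      s"expected a v2 entry, got: ${lines.headOption}")
    val metadataJson = lines(1) // one line of OffsetSeqMetadata JSON
    val offsets = lines.drop(2).map { line =>
      val colon = line.indexOf(':')
      require(colon != -1, s"invalid sourceId:offset line: $line")
      val sourceId = line.substring(0, colon)
      val offsetStr = line.substring(colon + 1)
      // "-" marks a void offset, mirroring SERIALIZED_VOID_OFFSET in the patch
      sourceId -> (if (offsetStr == "-") None else Some(offsetStr))
    }.toMap
    V2Entry(metadataJson, offsets)
  }

  def main(args: Array[String]): Unit = {
    // Shape follows structured-streaming/offset-map/3 above; the conf map in the
    // metadata JSON is trimmed here for brevity.
    val entry = parseV2Entry(Seq(
      "v2",
      """{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{}}""",
      "0:3"))
    assert(entry.offsets == Map("0" -> Some("3")))
    println(entry)
  }
}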