[SPARK-49411][SS] Communicate State Store Checkpoint ID between driver and stateful operators #47895

Closed
wants to merge 24 commits into from

Contributor

@siying siying commented Aug 27, 2024

What changes were proposed in this pull request?

This is an incremental step to implement RocksDB state store checkpoint format V2.

Once the conf STATE_STORE_CHECKPOINT_FORMAT_VERSION is set to version 2 or higher, the executor returns the checkpoint ID to the driver (only done for RocksDB). The driver stores it locally. For the next batch, the State Store Checkpoint ID is sent to the executor to be used to load the state store. If the version loaded locally on the executor doesn't match the uniqueID, it will reload from the checkpoint.

There is no behavior change if the default checkpoint format is used.
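
The round trip described above can be sketched roughly as follows. This is only an illustration under stated assumptions: `LoadedState`, `CheckpointIdRoundTrip`, `commit`, and `mustReload` are hypothetical names, not Spark's API, and the ID is assumed to be a random UUID string.

```scala
import java.util.UUID

// All names below are hypothetical; they only illustrate the message flow.

// Executor-side view of one state store: the loaded version and its checkpoint ID.
final case class LoadedState(version: Long, checkpointId: Option[String])

object CheckpointIdRoundTrip {
  // Executor: committing produces the next version and a fresh unique ID,
  // which is what gets reported back to the driver.
  def commit(loaded: LoadedState): LoadedState =
    LoadedState(loaded.version + 1, Some(UUID.randomUUID().toString))

  // Executor: given the (version, ID) the driver sends for the next batch,
  // report whether the locally loaded state must be discarded and reloaded.
  def mustReload(loaded: LoadedState, version: Long, idFromDriver: Option[String]): Boolean =
    loaded.version != version ||
      (idFromDriver.isDefined && loaded.checkpointId != idFromDriver)
}
```

In this sketch, an executor whose local state came from a different task attempt would hold a different ID than the one the driver recorded, so `mustReload` returns true even though the version numbers agree.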

Why are the changes needed?

This is an incremental step in the project to introduce a new RocksDB State Store checkpoint format. The new format is to simplify the checkpoint mechanism to make it less bug-prone, and fix some unexpected query results in rare queries.

Does this PR introduce any user-facing change?

No

How was this patch tested?

A new unit test is added to cover the format version. Another unit test is added to validate that the uniqueID is passed back and forth as expected.

Was this patch authored or co-authored using generative AI tooling?

No

@siying siying marked this pull request as draft August 27, 2024 18:33
@siying siying changed the title [WIP] Communicate CheckpointID between driver and stateful operators [SPARK-49411][SS] Communicate CheckpointID between driver and stateful operators Aug 27, 2024
@siying siying marked this pull request as ready for review August 27, 2024 22:22
val isFirstBatch: Boolean)
val isFirstBatch: Boolean,
val currentCheckpointUniqueId:
MutableMap[Long, Array[String]] = MutableMap[Long, Array[String]]())
Contributor

@WweiL WweiL Aug 29, 2024

Can we add comments on what these unique IDs map to? I believe the key is the operator ID?

Contributor

also better name it currentStateUniqueId as it is only related to state store not general checkpoint

Contributor

I'm also confused by this. When I sketched an implementation of your proposal in my head, my assumption would be that IncrementalExecution would get just an ID, perhaps a single Long, that would correspond to the ID that it would bake into the physical plan sent to executors. So why is a map needed?

Contributor Author

I'll add a comment, but it is basically operatorID->partitionID->checkpointID
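
A minimal sketch of that layout, assuming the array index is the partition ID (the thread later confirms partition IDs are in [0, numPartitions)). The class `CheckpointIdBook` and its methods are hypothetical and only mirror the `MutableMap[Long, Array[String]]` type from the diff.

```scala
import scala.collection.mutable

// Hypothetical driver-side bookkeeping; names are illustrative only.
final class CheckpointIdBook {
  // operatorId -> (array index = partitionId -> state store checkpoint ID)
  private val ids = mutable.Map.empty[Long, Array[String]]

  // Record the IDs reported by every partition of one operator.
  def update(operatorId: Long, idsByPartition: Array[String]): Unit =
    ids.put(operatorId, idsByPartition)

  // Look up the ID one partition should load for the next batch, if known.
  def lookup(operatorId: Long, partitionId: Int): Option[String] =
    ids.get(operatorId)
      .filter(arr => partitionId >= 0 && partitionId < arr.length)
      .map(arr => arr(partitionId))
}
```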

private def updateCheckpointId(
execCtx: MicroBatchExecutionContext,
latestExecPlan: SparkPlan): Unit = {
// This function cannot handle MBP now.
Contributor

unnecessary comment

if (loadedVersion != version) {
if (loadedVersion != version ||
(checkpointFormatVersion >= 2 && checkpointUniqueId.isDefined &&
(!loadedCheckpointId.isDefined || checkpointUniqueId.get != loadedCheckpointId.get))) {
Contributor

nit: loadedCheckpointId.isEmpty

.agg(count("*"))
.as[(Int, Long)]

// Run the stream with changelog checkpointing disabled.
Contributor

typo?

// Store checkpointIDs for state store checkpoints to be committed or have been committed to
// the commit log.
// operatorID -> (partitionID -> uniqueID)
private val currentCheckpointUniqueId = MutableMap[Long, Array[String]]()
Contributor

Maybe this is better to be put into the stream execution context

Contributor

operatorID -> (partitionID -> uniqueID), is this supposed to mean a map of maps? If so, then why is the type of currentCheckpointUniqueId just a single map?

I also don't fully understand why we would need a unique map for every operator X partition. Why is it not sufficient to have the following protocol, where we have one unique ID for every batch:

For the first batch, an ID is created and sent to all executors. When all tasks finish, that ID is persisted into the commit log. It is also kept in memory for the subsequent batch.

For any other batch, if there does not exist an ID in memory from the previous batch, then it must be read from the commit log and brought into memory. (This is the restart case.)

Then, using the ID in memory from the previous batch (call that prevId), this is sent to all executors in the physical plan, as well as a new ID for the current batch (call this currId). Before any processing starts, executors must load and use the state for prevId to process the current batch. Then, they can start processing, and they upload their state as <state file name>_currId.<changelog|snapshot>.

What's wrong with that?

Contributor Author

Right now, the uniqueID is generated in executor. As a potential optimization, the driver can send a uniqueID to all executors, but executors still need to modify it to make it unique among all attempts of the same task. After doing that, the IDs won't be unique anymore, so we need different IDs per partition.
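
To illustrate the "potential optimization" mentioned above, here is a hedged sketch of how a driver-issued ID could be made distinct per task attempt. The derivation scheme and the name `AttemptScopedId` are assumptions made for illustration; the PR itself generates the ID on the executor.

```scala
// Hypothetical scheme: not what the PR implements.
object AttemptScopedId {
  // Derive a per-attempt ID from a driver-issued batch ID, so that two attempts
  // of the same task (retry or speculation) never produce the same ID.
  def derive(driverId: String, partitionId: Int, attemptNumber: Int): String =
    s"$driverId-p$partitionId-a$attemptNumber"
}
```

Because the derived value differs per partition and per attempt, the driver would still need to learn which one was committed for each partition, which is why a single per-batch ID is not enough.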

try {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
rocksDB.load(version, true)
rocksDB.load(version, uniqueId, true)
Contributor

rocksDB.load(
version,
if (storeConf.stateStoreCheckpointFormatVersion >= 2) uniqueId else None)

@volatile private var LastCommitBasedCheckpointId: Option[String] = None
@volatile private var lastCommittedCheckpointId: Option[String] = None
@volatile private var loadedCheckpointId: Option[String] = None
@volatile private var sessionCheckpointId: Option[String] = None
Contributor

Should reset these to None in rollback()

Contributor

neilramaswamy commented Sep 10, 2024

fix some unexpected query results in rare queries

@siying can you provide some context about which situations these are specifically?

(Edit, seems to be here in the design doc.)

Contributor

@neilramaswamy neilramaswamy left a comment

Going to stop reviewing since I have a few fundamental questions regarding the protocol.

@@ -105,7 +105,7 @@ class StreamStreamJoinStatePartitionReader(
val stateInfo = StatefulOperatorStateInfo(
partition.sourceOptions.stateCheckpointLocation.toString,
partition.queryId, partition.sourceOptions.operatorId,
partition.sourceOptions.batchId + 1, -1)
partition.sourceOptions.batchId + 1, -1, None)
Contributor

Why is this None? I would imagine that users of the state data source reader now have to specify the ID that they would like to read, given that state stores are now not uniquely identified by operator/partition/name, but by id/operator/partition/name?

Contributor Author

Good point. Will check.

Contributor

Any update here?

Contributor Author

Any update here?

@neilramaswamy here, we don't know the checkpointID. We would know the ID after we persist to the commit log. But now it is just like the first time we restart the query -- we don't know it. I can leave a TODO.

Contributor

how do we load the previous state store correctly in this case then in a stream restart?

Contributor Author

The code needs to change after we persist the ID to commit logs. The ID needs to be read from the commit logs and passed to here. For now, we can say the state store reader isn't supported in this new mode (it's likely working accidentally, but it's not worth even testing it). There is already a TODO comment above.

Comment on lines 134 to 144
val ret = StatefulOperatorStateInfo(
checkpointLocation,
runId,
statefulOperatorId.getAndIncrement(),
operatorId,
currentBatchId,
numStateStores)
numStateStores,
currentCheckpointUniqueId.get(operatorId))
ret
Contributor

ret is not needed

case e: StreamingDeduplicateWithinWatermarkExec =>
assert(e.stateInfo.isDefined)
updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, e.getCheckpointInfo())
// TODO Need to deal with FlatMapGroupsWithStateExec, TransformWithStateExec,
Contributor

Why not?

And I also don't see why we need to enumerate all of these here. Can we leverage the StatefulOperator trait and use that to get the state info? It should clean this up quite a bit.

Contributor

You will, though, probably have to do some work to make sure that getCheckpointInfo can be called for any stateful operator.

watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
val latestExecPlan = execCtx.executionPlan.executedPlan
watermarkTracker.updateWatermark(latestExecPlan)
if (sparkSession.sessionState.conf.stateStoreCheckpointFormatVersion >= 2) {
Contributor

I don't really like the >= 2 sprinkled everywhere. Can you define a constant somewhere, and then have a utility method that you call

Contributor

+1.

I'd introduce a StreamingCheckpointProtocolVersion object or something and then add utility methods like:

def supportsStateCheckpointIds

Contributor Author

I created StatefulOperatorStateInfo.enableStateStoreCheckpointIds() after Neil's comment. This is a leftover. Will switch.
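
A minimal sketch of the kind of helper being asked for, so that the `>= 2` comparison lives in one place. The object name, constant, and method here are hypothetical; they are not the `StatefulOperatorStateInfo.enableStateStoreCheckpointIds()` implementation referenced above.

```scala
// Hypothetical helper; the object and member names are illustrative, not Spark's.
object StateStoreCheckpointFormat {
  // First format version in which state store checkpoint IDs are exchanged.
  val MinVersionWithCheckpointIds: Int = 2

  // Single place that answers "does this format version use checkpoint IDs?".
  def supportsStateCheckpointIds(formatVersion: Int): Boolean =
    formatVersion >= MinVersionWithCheckpointIds
}
```

Call sites would then read `if (StateStoreCheckpointFormat.supportsStateCheckpointIds(version))` instead of repeating the numeric comparison.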

val isFirstBatch: Boolean)
val isFirstBatch: Boolean,
val currentCheckpointUniqueId:
MutableMap[Long, Array[String]] = MutableMap[Long, Array[String]]())
Contributor

Is it always true that partition IDs are in [0, numPartitions)?

Contributor Author

Yes it is true.

})
}

private def updateCheckpointId(
Contributor

Let me make sure I understand the flow here:

  1. Micro-batch ends, we call updateCheckpointId
  2. This goes through every stateful operator and calls updateCheckpointIdForOperator
  3. For each operator, we call into its getCheckpointInfo method
    1. That method will access the checkpointInfoAccumulator
    2. The checkpointInfoAccumulator is appended to using the unique ID from the state store after processing all data on the task
  4. In the future, we'll write this to the commit log.

Is this right?

Contributor Author

That's right. I should write a comment somewhere.

@@ -803,6 +843,14 @@ class RocksDB(
/** Get the write buffer manager and cache */
def getWriteBufferManagerAndCache(): (WriteBufferManager, Cache) = (writeBufferManager, lruCache)

def getLatestCheckpointInfo(partitionId: Int): StateStoreCheckpointInfo = {
Contributor

Will this ever be called if lastCommittedCheckpointId is None or LastCommitBasedCheckpointId is None?

Contributor Author

This will always be called. The caller has no knowledge on what's going on there.

Contributor

can you add scaladocs please

Comment on lines 163 to 174
// variables to manage checkpoint ID. Once a checkpoingting finishes, it nees to return
// the `lastCommittedCheckpointId` as the committed checkpointID, as well as
// `LastCommitBasedCheckpointId` as the checkpontID of the previous version that is based on.
// `loadedCheckpointId` is the checkpointID for the current live DB. After the batch finishes
// and checkpoint finishes, it will turn into `LastCommitBasedCheckpointId`.
// `sessionCheckpointId` store an ID to be used for future checkpoints. It is kept being used
// until we have to use a new one. We don't need to reuse any uniqueID, but reusing when possible
// can help debug problems.
@volatile private var LastCommitBasedCheckpointId: Option[String] = None
@volatile private var lastCommittedCheckpointId: Option[String] = None
@volatile private var loadedCheckpointId: Option[String] = None
@volatile private var sessionCheckpointId: Option[String] = None
Contributor

We never read sessionCheckpointId and the comment doesn't really help me. What is it being used for?

Is there a reason LastCommitBasedCheckpointId is capitalized? And LastCommitBasedCheckpointId isn't even used in this PR since there is another TODO that says // TODO validate baseCheckpointId? Is that right?

Comment on lines 171 to 174
@volatile private var LastCommitBasedCheckpointId: Option[String] = None
@volatile private var lastCommittedCheckpointId: Option[String] = None
@volatile private var loadedCheckpointId: Option[String] = None
@volatile private var sessionCheckpointId: Option[String] = None
Contributor

Can you comment specifically why these are marked as volatile? From what I can tell, these are only read/written to by the query execution thread.

partitionId: Int,
batchVersion: Long,
checkpointId: Option[String],
baseCheckpointId: Option[String])
Contributor

We call this checkpointId in some places and baseCheckpointId in others? Can you clarify which is which, and what specifically it should be here?

Comment on lines +205 to +251
.map {
case (key, values) => key -> values.head
}
Contributor

This list would be non-zero only if there was a task retry/speculative execution, right?

Contributor

@neilramaswamy neilramaswamy Sep 17, 2024

And as discussed earlier today offline, this has the issue of not working if the same partition has multiple state stores, e.g. in a stream-stream join, which is actually a very serious issue.

Contributor

if there was a task retry/speculative execution

Also if there is a fan-out in foreachBatch, i.e. df.write.save() executed twice

@siying siying force-pushed the unique_id2 branch 2 times, most recently from 8dab4a8 to efe3ab5 Compare September 23, 2024 23:13
Contributor

@neilramaswamy neilramaswamy left a comment

High-level ideas look good, nothing super fundamental. Some clarity, testing, and question comments.

@@ -190,6 +190,11 @@ trait StateStore extends ReadStateStore {
/** Current metrics of the state store */
def metrics: StateStoreMetrics

/** Return information on recently generated checkpoints */
def getCheckpointInfo: StateStoreCheckpointInfo = {
StateStoreCheckpointInfo(-1, -1, None, None)
Contributor

Why default implementation? If all the sub-classes are overriding it, let's just make it required with no default.

Contributor

+1, it could lead to bugs if this is incorrect, right? I'd remove a default implementation in such a case (it may require changes in tests I guess, but that can be handled with a trait or something)

Comment on lines 244 to 245
// The checkpoint ID for a checkpoint at `batchVersion`. This is used to identify the checkpoint
checkpointId: Option[String],
Contributor

If we use a String, we need to mention that it's not necessarily one checkpoint ID. It could be many, comma-separated.

But to be honest, I don't think we should be using String here, because it's ambiguous. Is it 1 checkpoint? 4 checkpoints? You cannot simply tell by looking at the code. The naming is also off in the case of multiple checkpoints; it's StateStore*s*CheckpointInfo.

I think it makes more sense for us to return, all the way through the accumulator, a Seq[String]. Then, the only place that the merging should happen is inside of def getCheckpointInfo inside of StateStoreWriter. This avoids us from awkwardly having one-off merging logic inside of the s/s join, even though I know it's the only place.

execCtx.batchId == -1 || v == execCtx.batchId + 1,
s"version $v doesn't match current Batch ID ${execCtx.batchId}")
Contributor

I don't understand the assertion here. We say v == batchId + 1 and then assert that v must match batchId?

Contributor Author

I can rephrase it, but batch n commits to state store version n+1.
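
One possible rephrasing of the check, sketched outside of Spark. The object `VersionCheck` and the exact message wording are assumptions; the condition is the same as in the diff (`batchId == -1 || v == batchId + 1`), and `require` is used here in place of the diff's `assert`.

```scala
object VersionCheck {
  // batchId == -1 is accepted unconditionally, mirroring the diff's condition.
  // Otherwise batch n must have committed state store version n + 1.
  def checkCommittedVersion(batchId: Long, version: Long): Unit =
    require(
      batchId == -1 || version == batchId + 1,
      s"state store version $version should be ${batchId + 1}, " +
        s"because batch $batchId commits to version ${batchId + 1}")
}
```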

@@ -72,7 +72,8 @@ class RocksDB(
localRootDir: File = Utils.createTempDir(),
hadoopConf: Configuration = new Configuration,
loggingId: String = "",
useColumnFamilies: Boolean = false) extends Logging {
useColumnFamilies: Boolean = false,
ifEnableCheckpointId: Boolean = false) extends Logging {
Contributor

More consistent to call it enableStateStoreCheckpointIds.

I also think that the term "checkpoint ID" is very confusing. The term makes it feel like it's an ID for an entire checkpoint, when really it's an ID for a particular state store that has been checkpointed.

I know it's a tedious modification to make. I would be happy to alleviate some of this work by creating a branch with that change and putting up a PR that you can merge back in this branch.

Contributor Author

Do you have a suggestion for a better name? I can definitely change it.

Contributor

State store checkpoint ID?

@@ -808,6 +824,45 @@ object SymmetricHashJoinStateManager {
result
}

def mergeStateStoreCheckpointInfo(
Contributor

I already commented about this elsewhere (that it shouldn't be in the symmetric hash join state manager), but this was confusing to read because it is used in two places:

  1. To merge the key with index state store with the key with index to value state store
  2. To merge the results from step (1) for both the left and the right into one result

testWithChangelogCheckpointingEnabled(
s"checkpointFormatVersion2 validate ID with dedup and groupBy") {
val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
TestStateStoreWrapper.clear()
Contributor

All of these can be refactored into a beforeEach in the class

(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
(SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
val checkpointDir = Utils.createTempDir().getCanonicalFile
checkpointDir.delete()
Contributor

Why do you need to delete this? And why not use withTempDir?

Contributor Author

That's a good point. I'll do it. I copy&pasted from a previous test without thinking.

@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}

testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
Contributor

From what I can tell, none of the new suites that were added cover the edge case in the design doc, right? There's no speculative execution here.

I think what you could do is create new manual StateStores that simulate the race here, without needing to write a query that does this. Right?

assert(checkpointInfoList.size == 12)
checkpointInfoList.foreach { l =>
assert(l.checkpointId.isDefined)
if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
Contributor

Sorry, I don't follow this. Why are we just checking these specific batchVersions? Shouldn't all of them, 0 to 5 inclusive, be present?

Comment on lines 482 to 489
for {
a <- checkpointInfoList
b <- checkpointInfoList
if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
} {
// if batch version exists, it should be the same as the checkpoint ID of the previous batch
assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
}
Contributor

This can definitely be refactored; you're using the same code snippet in all tests? Seems like a StateStoreCheckpointIdTestUtils could be good.
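
A sketch of what such a shared helper might look like. `CkptInfo` is a hypothetical stand-in for the checkpoint info record (its fields follow the ones shown in the diff), and `brokenLinks` is an assumed name; the chain condition is the negation of the assertion in the snippet above.

```scala
// Hypothetical stand-in for the checkpoint info record used in the tests.
final case class CkptInfo(
    partitionId: Int,
    batchVersion: Long,
    checkpointId: Option[String],
    baseCheckpointId: Option[String])

object StateStoreCheckpointIdTestUtils {
  // Returns the (previous, current) pairs that break the chain: same partition,
  // consecutive versions, a base ID is present on the current entry, but it
  // differs from the previous entry's checkpoint ID.
  def brokenLinks(infos: Seq[CkptInfo]): Seq[(CkptInfo, CkptInfo)] =
    for {
      a <- infos
      b <- infos
      if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
      if a.baseCheckpointId.isDefined && a.baseCheckpointId != b.checkpointId
    } yield (b, a)
}
```

Each test could then assert that `brokenLinks(checkpointInfoList)` is empty, and a failure would show exactly which pair of entries disagrees.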

@siying siying changed the title [SPARK-49411][SS] Communicate CheckpointID between driver and stateful operators [SPARK-49411][SS] Communicate State Store Checkpoint ID between driver and stateful operators Sep 30, 2024
Contributor

@neilramaswamy neilramaswamy left a comment

Assuming we have much stronger testing in the future, I'd be ok merging this.

Comment on lines 321 to 325
stateStoreCkptIds(0), stateStoreCkptIds(1), skippedNullValueCount)
val rightSideJoiner = new OneSideHashJoiner(
RightSide, right.output, rightKeys, rightInputIter,
condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId,
skippedNullValueCount)
stateStoreCkptIds(2), stateStoreCkptIds(3), skippedNullValueCount)
Contributor

This seems fragile, but I guess it's not a merge blocker.

Contributor Author

I can definitely first deserialize to a case class, if you think that's better. Or do you think we should serialize into the checkpointInfo itself? I feel like it might be over-engineering, considering that the long term direction is probably to merge the 4 state stores into one.
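
For reference, the "deserialize to a case class" option could look roughly like this. The names `JoinStateStoreCkptIds` and `JoinCkptIdCodec` are hypothetical; the field order is assumed to match the positional indices 0 to 3 used in the diff (left side first, then right side).

```scala
// Hypothetical wrapper for the four join state store checkpoint IDs.
final case class JoinStateStoreCkptIds(
    leftKeyToNumValues: String,
    leftValueToNumKeys: String,
    rightKeyToNumValues: String,
    rightValueToNumKeys: String)

object JoinCkptIdCodec {
  // Decode the positional array; None if it does not have exactly four entries.
  def fromArray(ids: Array[String]): Option[JoinStateStoreCkptIds] = ids match {
    case Array(lk, lv, rk, rv) => Some(JoinStateStoreCkptIds(lk, lv, rk, rv))
    case _ => None
  }
}
```

The trade-off is the one raised in the reply: named fields remove the reliance on index order, at the cost of extra code that may become unnecessary if the four stores are later merged into one.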

// we have to use a new one. We have to update `sessionStateStoreCkptId` if we reload a previous
// batch version, because we have to use a new checkpointID for re-committing a version.
// The reusing is to help debugging but is not required for the algorithm to work.
private var LastCommitBasedStateStoreCkptId: Option[String] = None
Contributor

I still don't understand why this is capitalized. I think we also ought to write down the threading model here. Who can read these? Who can write them? If there is concurrent access, what synchronizes access?

Also this comment has several typos, e.g. checkpoingting and nees, etc.

Contributor Author

Good catch. I'll fix them.

Contributor

@brkyvz brkyvz left a comment

Left mostly small comments. I would recommend leaving out the accumulator usage from this PR, as the correctness of that code path is somewhat dubious.

@@ -243,6 +243,7 @@ trait FlatMapGroupsWithStateExecBase
stateManager.stateSchema,
NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType),
stateInfo.get.storeVersion,
stateInfo.get.getStateStoreCkptId(partitionId).map(_(0)),
Contributor

uber nit: this would seem too magical for new Scala learners. Can we write this out as _.get(0) if this is an array or Seq?

Contributor Author

I am a newbie in Scala. I checked, but there is no get() on a Scala array. There is .apply(), but it is even more confusing to me. I'll replace those (0) with .head. For the J&J case, I think (0), (1), etc. looks OK.

currentBatchId,
numStateStores)
numStateStores,
currentStateStoreCkptId.get(operatorId))
Contributor

are there any assertions we can add that this isn't empty for a batch after version 0?

Contributor Author

The assertion will be more straightforward after we add support for persisting the ID to the commit logs. For now, it is also empty when the query has just started. I can leave a comment here saying we should add an assertion once only batch 0 can be empty.

@@ -900,12 +906,46 @@ class MicroBatchExecution(
*/
protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {}

private def updateStateStoreCkptIdForOperator(
Contributor

can you please add scala docs for the methods below?

Comment on lines 920 to 923
currentStateStoreCkptId.put(opId, checkpointInfo.map { c =>
assert(c.stateStoreCkptId.isDefined)
c.stateStoreCkptId.get
})
Contributor

nit: inlining like the put side makes the code a bit harder to read. Can you please move this out into a variable?
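One way to read this suggestion is sketched below: name the transformed value first, then store it. `CkptInfo` and the map's type are simplified stand-ins assumed for illustration; they are not the PR's actual definitions.

```scala
import scala.collection.mutable

object PutSideSketch {
  // Simplified stand-in for the PR's checkpoint info type (assumption).
  final case class CkptInfo(partitionId: Int, stateStoreCkptId: Option[Array[String]])

  // Operator ID -> checkpoint IDs per partition (assumed shape).
  val currentStateStoreCkptId = mutable.Map.empty[Long, Array[Array[String]]]

  def update(opId: Long, checkpointInfo: Array[CkptInfo]): Unit = {
    // Give the mapped result a name so the put call stays a one-liner.
    val ckptIdsPerPartition: Array[Array[String]] = checkpointInfo.map { c =>
      assert(c.stateStoreCkptId.isDefined,
        s"Missing state store checkpoint ID for partition ${c.partitionId}")
      c.stateStoreCkptId.get
    }
    currentStateStoreCkptId.put(opId, ckptIdsPerPartition)
  }
}
```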

Comment on lines 850 to 855
joinCkptInfo.left.keyToNumValues.stateStoreCkptId.map(
Array(
_,
joinCkptInfo.left.valueToNumKeys.stateStoreCkptId.get,
joinCkptInfo.right.keyToNumValues.stateStoreCkptId.get,
joinCkptInfo.right.valueToNumKeys.stateStoreCkptId.get)),
Contributor

can you not inline these please?

Comment on lines 841 to 846
assert(
joinCkptInfo.left.keyToNumValues.partitionId == joinCkptInfo.right.keyToNumValues.partitionId)
assert(joinCkptInfo.left.keyToNumValues.batchVersion ==
joinCkptInfo.right.keyToNumValues.batchVersion)
assert(joinCkptInfo.left.keyToNumValues.stateStoreCkptId.isDefined ==
joinCkptInfo.right.keyToNumValues.stateStoreCkptId.isDefined)
Contributor

messages for the assertions please
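As a rough illustration of the request, Scala's `assert` accepts a second argument that is included in the resulting `AssertionError`. The `StoreCkptInfo` type below is a simplified stand-in assumed for this sketch, and the message wording is only a suggestion.

```scala
object AssertMessageSketch {
  // Simplified stand-in for the PR's per-store checkpoint info (assumption).
  final case class StoreCkptInfo(
      partitionId: Int,
      batchVersion: Long,
      stateStoreCkptId: Option[String])

  /** Checks that the two sides of a join report consistent metadata. */
  def validate(left: StoreCkptInfo, right: StoreCkptInfo): Unit = {
    assert(left.partitionId == right.partitionId,
      s"Partition ID mismatch: left=${left.partitionId}, right=${right.partitionId}")
    assert(left.batchVersion == right.batchVersion,
      s"Batch version mismatch: left=${left.batchVersion}, right=${right.batchVersion}")
    assert(left.stateStoreCkptId.isDefined == right.stateStoreCkptId.isDefined,
      "Checkpoint ID must be defined on both sides or on neither")
  }
}
```

With messages in place, a failure names the mismatching values instead of reporting only "assertion failed".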

Comment on lines 867 to 870
// Stream-stream join has 4 state stores instead of one. So it will generate 4 different
// checkpoint IDs. They are translated from each joiners' state store into an array
// through mergeStateStoreCheckpointInfo(). This function is used to read it back into
// individual state store checkpoint IDs.
Contributor

can you make it the method scaladoc please
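A minimal sketch of the pack-and-unpack idea described in the code comment above, written with scaladoc and without inlining. The `JoinCkptIds` type, the `merge` and `split` names, and the array order are assumptions for illustration; the store names are borrowed from identifiers that appear in this thread.

```scala
object JoinCkptIdSketch {
  // Hypothetical container for the four state stores of a stream-stream join.
  final case class JoinCkptIds(
      leftKeyToNumValues: String,
      leftKeyWithIndexToValue: String,
      rightKeyToNumValues: String,
      rightKeyWithIndexToValue: String)

  /** Packs the four checkpoint IDs into one array in a fixed order. */
  def merge(ids: JoinCkptIds): Array[String] = {
    Array(
      ids.leftKeyToNumValues,
      ids.leftKeyWithIndexToValue,
      ids.rightKeyToNumValues,
      ids.rightKeyWithIndexToValue)
  }

  /** Reads the array back into individual IDs; the order must mirror `merge`. */
  def split(merged: Array[String]): JoinCkptIds = {
    require(merged.length == 4, s"Expected 4 checkpoint IDs, got ${merged.length}")
    JoinCkptIds(merged(0), merged(1), merged(2), merged(3))
  }
}
```

Keeping both directions next to each other makes it harder for the array order to drift between the writer and the reader.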

Comment on lines +228 to +234
/**
* Aggregator used for the executors to pass new state store checkpoints' IDs to driver.
* For the general checkpoint ID workflow, see comments of
* class [[StatefulOperatorStateInfo]]
*/
val checkpointInfoAccumulator: CollectionAccumulator[StatefulOpStateStoreCheckpointInfo] = {
SparkContext.getActive.map(_.collectionAccumulator[StatefulOpStateStoreCheckpointInfo]).get
Contributor

please don't use accumulators for this but prefer an RPC channel. Accumulators can cause some havoc with failed or speculative tasks. Is it possible to remove this part from this PR and have that be a separate PR?

Comment on lines 584 to 585
keyToNumValuesStateStoreCkptId,
keyWithIndexToValueStateStoreCkptId,
Contributor

mind using parameter names for these two to prevent accidental ordering issues?
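To illustrate the concern, here is a small sketch with two parameters of the same type. The `describe` function is hypothetical; only the two parameter names come from the diff above.

```scala
object NamedArgsSketch {
  // Hypothetical function with two same-typed parameters (assumption).
  def describe(
      keyToNumValuesStateStoreCkptId: Option[String],
      keyWithIndexToValueStateStoreCkptId: Option[String]): String = {
    s"k2n=$keyToNumValuesStateStoreCkptId, kwi2v=$keyWithIndexToValueStateStoreCkptId"
  }

  val k2n: Option[String] = Some("id-1")
  val kwi2v: Option[String] = Some("id-2")

  // Positional call with the arguments swapped: compiles, but is silently wrong.
  val positionalSwapped: String = describe(kwi2v, k2n)

  // Named call: the written order no longer matters.
  val named: String = describe(
    keyWithIndexToValueStateStoreCkptId = kwi2v,
    keyToNumValuesStateStoreCkptId = k2n)
}
```

Because both parameters are `Option[String]`, the compiler cannot catch a swap in the positional form; named arguments make the pairing explicit at the call site.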

Contributor

@brkyvz brkyvz left a comment

Wanted to leave quick feedback. Still halfway through my pass.

Comment on lines 934 to 935
* @param execCtx
* @param latestExecPlan
Contributor

you can remove these

Contributor

@siying missed this comment

latestExecPlan: SparkPlan): Unit = {
latestExecPlan.collect {
case e: StateStoreWriter =>
assert(e.stateInfo.isDefined)
Contributor

did you forget addressing this?

- watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
+ val latestExecPlan = execCtx.executionPlan.executedPlan
+ watermarkTracker.updateWatermark(latestExecPlan)
+ if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sparkSession.sessionState.conf)) {
Contributor

should you be using sparkSessionForStream here? Otherwise this can change from microbatch to microbatch, which is risky
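The risk raised here can be shown with a toy model. `ToySession` below is a stand-in written for this sketch and is not Spark's API; it only assumes that cloning a session copies its configuration, so later changes to the original session are not seen by the clone.

```scala
object SessionSnapshotSketch {
  // Toy stand-in for a session that holds configuration (not Spark's API).
  final class ToySession(initial: Map[String, String]) {
    private var conf: Map[String, String] = initial
    def set(key: String, value: String): Unit = { conf = conf.updated(key, value) }
    def get(key: String): Option[String] = conf.get(key)
    // The clone starts from a copy of the current configuration.
    def cloneSession(): ToySession = new ToySession(conf)
  }

  // Hypothetical config key, used only for this example.
  val key = "checkpointFormatVersion"

  val userSession = new ToySession(Map(key -> "2"))
  val sessionForStream: ToySession = userSession.cloneSession()

  // The user changes the setting while the query is running.
  userSession.set(key, "1")

  val seenByUserSession: Option[String] = userSession.get(key)
  val seenByStream: Option[String] = sessionForStream.get(key)
}
```

Reading the flag from the cloned session keeps the decision stable across micro-batches, whereas reading it from the user-facing session could flip mid-query.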

case class StateStoreCheckpointInfo(
partitionId: Int,
batchVersion: Long,
// The checkpoint ID for a checkpoint at `batchVersion`. This is used to identify the checkpoint
Contributor

can you move these above to @param lines in the scaladoc?
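A sketch of what the requested layout could look like, with the field comments moved into `@param` tags. The class name carries a `Sketch` suffix because this is a simplified stand-in, and the descriptions and the `Option[String]` type are assumptions rather than the PR's final text.

```scala
/**
 * Checkpoint information reported for one state store (simplified sketch).
 *
 * @param partitionId      partition that the state store belongs to
 * @param batchVersion     batch version at which the checkpoint was taken
 * @param stateStoreCkptId ID of the checkpoint at `batchVersion`, if one was generated
 */
final case class StateStoreCheckpointInfoSketch(
    partitionId: Int,
    batchVersion: Long,
    stateStoreCkptId: Option[String])
```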

@@ -90,6 +91,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
val inputIter = dataRDD.iterator(partition, ctxt)
val store = StateStore.getReadOnly(
storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
stateStoreCkptIds.map(_(partition.index).head),
Contributor

nit: _.apply(...).head

@@ -126,6 +129,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
val inputIter = dataRDD.iterator(partition, ctxt)
val store = StateStore.get(
storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
uniqueId.map(_(partition.index).head),
Contributor

ditto

Contributor

@brkyvz brkyvz left a comment

LGTM!

- private val sparkSessionForStream = sparkSession.cloneSession()
+ protected val sparkSessionForStream = sparkSession.cloneSession()
Contributor

for the future - I feel like we should refactor these abstractions a bit to ensure that developers cannot repeat the same session-misuse mistake. Today it's too subtle and easy to hit

}

object StatefulOperatorStateInfo {
def enableStateStoreCheckpointIds(conf: SQLConf): Boolean = {
Contributor

docs please

@HeartSaVioR
Contributor

e4863d6 has passed CI and I don't think 1c380ee would break the CI in any way - even CI for 1c380ee passed the streaming tests.

@HeartSaVioR
Contributor

HeartSaVioR commented Oct 18, 2024

I'm merging the PR on behalf of @brkyvz as he asked personally. It's also approved by two more contributors (my team) so I feel OK to merge this. Just to leave DISCLAIMER.

@HeartSaVioR
Contributor

Thanks! Merging to master.

@siying
Contributor Author

siying commented Oct 18, 2024

@HeartSaVioR thank you for your help!

himadripal pushed a commit to himadripal/spark that referenced this pull request Oct 19, 2024
…r and stateful operators

### What changes were proposed in this pull request?

This is an incremental step to implement RocksDB state store checkpoint format V2.

Once conf STATE_STORE_CHECKPOINT_FORMAT_VERSION is set to be higher than version 2, the executor returns checkpointID to the driver (only done for RocksDB). The driver stores it locally. For the next batch, the State Store Checkpoint ID is sent to the executor to be used to load the state store. If the local version of the executor doesn't match the uniqueID, it will reload from the checkpoint.

There is no behavior change if the default checkpoint format is used.

### Why are the changes needed?

This is an incremental step in the project to introduce a new RocksDB State Store checkpoint format. The new format simplifies the checkpoint mechanism to make it less bug-prone, and fixes some unexpected query results in rare queries.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
A new unit test is added to cover the format version. Another unit test is added to validate that the uniqueID is passed back and forth as expected.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47895 from siying/unique_id2.

Authored-by: Siying Dong <siying.dong@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>