@@ -39,6 +39,7 @@ import org.apache.spark.internal.LogKeys.{CHECKSUM, NUM_BYTES, PATH, TIMEOUT}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager.CancellableFSDataOutputStream
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.Utils
 
 /** Information about the creator of the checksum file. Useful for debugging */
 case class ChecksumFileCreatorInfo(
@@ -500,16 +501,14 @@ class ChecksumCancellableFSDataOutputStream(
   @volatile private var closed = false
 
   override def cancel(): Unit = {
-    val mainFuture = Future {
+    // Cancel both streams synchronously rather than using futures. If the current thread is
+    // interrupted and we call this method, scheduling work on futures would immediately throw
+    // InterruptedException leaving the streams in an inconsistent state.
+    Utils.tryWithSafeFinally {
       mainStream.cancel()
-    }(uploadThreadPool)
-
-    val checksumFuture = Future {
+    } {
       checksumStream.cancel()
-    }(uploadThreadPool)
-
-    awaitResult(mainFuture, Duration.Inf)
-    awaitResult(checksumFuture, Duration.Inf)
+    }
   }
 
   override def close(): Unit = {
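Why the futures had to go: ThreadUtils.awaitResult, like scala.concurrent.Await.result underneath it, waits interruptibly, so a task thread that was just killed (interrupt flag already set) throws InterruptedException out of the wait before either cancel() can be observed to run. A self-contained sketch of that failure mode, using plain scala.concurrent as a stand-in for Spark's ThreadUtils:

import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object InterruptedAwaitDemo {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)

    // A killed Spark task reaches cancel() with its interrupt flag already set.
    Thread.currentThread().interrupt()

    try {
      // Await.result blocks interruptibly: with the flag set it throws
      // InterruptedException (unless the future happens to complete first),
      // so the caller cannot tell whether the cleanup inside the future ran.
      Await.result(Future { /* stream.cancel() would run here */ }, Duration.Inf)
    } catch {
      case _: InterruptedException => println("wait aborted before cleanup was observed")
    } finally {
      pool.shutdown()
    }
  }
}

Utils.tryWithSafeFinally sidesteps this entirely: both cancels run synchronously on the calling thread, the second inside a finally block, and a failure there is attached to the first exception as suppressed rather than masking it.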
@@ -1650,17 +1650,23 @@ class RocksDB(
    * Drop uncommitted changes, and roll back to previous version.
    */
   def rollback(): Unit = {
-    numKeysOnWritingVersion = numKeysOnLoadedVersion
-    numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
-    loadedVersion = -1L
-    lastCommitBasedStateStoreCkptId = None
-    lastCommittedStateStoreCkptId = None
-    loadedStateStoreCkptId = None
-    sessionStateStoreCkptId = None
-    lineageManager.clear()
-    changelogWriter.foreach(_.abort())
-    // Make sure changelogWriter gets recreated next time.
-    changelogWriter = None
+    logInfo(
+      log"Rolling back uncommitted changes on version ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+    try {
+      numKeysOnWritingVersion = numKeysOnLoadedVersion
+      numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
+      loadedVersion = -1L
+      lastCommitBasedStateStoreCkptId = None
+      lastCommittedStateStoreCkptId = None
+      loadedStateStoreCkptId = None
+      sessionStateStoreCkptId = None
+      lineageManager.clear()
+      changelogWriter.foreach(_.abort())
+    } finally {
+      // Make sure changelogWriter gets recreated next time even if the changelogWriter aborts with
+      // an exception.
+      changelogWriter = None
+    }
     logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
   }
 
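The point of the new finally: even if changelogWriter.foreach(_.abort()) throws, the writer reference is still cleared, so the next load() starts from a fresh writer instead of a half-aborted stale one. A minimal sketch of that guarantee, with FlakyWriter as a hypothetical stand-in for the changelog writer (not Spark code):

object RollbackFinallyDemo {
  // Hypothetical stand-in for a changelog writer whose abort fails.
  class FlakyWriter {
    def abort(): Unit = throw new RuntimeException("abort failed mid-rollback")
  }

  def main(args: Array[String]): Unit = {
    var changelogWriter: Option[FlakyWriter] = Some(new FlakyWriter)
    try {
      try {
        changelogWriter.foreach(_.abort()) // throws
      } finally {
        changelogWriter = None // still runs, mirroring the patched rollback()
      }
    } catch {
      case e: RuntimeException =>
        println(s"abort still surfaced: ${e.getMessage}")
    }
    // The stale writer can no longer leak into the next load().
    assert(changelogWriter.isEmpty)
  }
}

Without the finally, a throwing abort() would have skipped the reset and left the stale writer in place for the next version.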
@@ -70,6 +70,10 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
 
   implicit def toArray(str: String): Array[Byte] = if (str != null) str.getBytes else null
 
+  implicit def toStr(bytes: Array[Byte]): String = if (bytes != null) new String(bytes) else null
+
+  def toStr(kv: ByteArrayPair): (String, String) = (toStr(kv.key), toStr(kv.value))
+
   case class FailureConf(ifEnableStateStoreCheckpointIds: Boolean, fileType: String) {
     override def toString: String = {
       s"ifEnableStateStoreCheckpointIds = $ifEnableStateStoreCheckpointIds, " +
@@ -824,6 +828,62 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
     }
   }
 
+  /**
+   * Test that verifies that when a task is interrupted, the store's rollback() method does not
+   * throw an exception and the store can still be used after the rollback.
+   */
+  test("SPARK-54585: Interrupted task calling rollback does not throw an exception") {
+    val hadoopConf = new Configuration()
+    hadoopConf.set(
+      STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key,
+      fileManagerClassName
+    )
+    withTempDirAllowFailureInjection { (remoteDir, _) =>
+      val sqlConf = new SQLConf()
+      sqlConf.setConfString("spark.sql.streaming.checkpoint.fileChecksum.enabled", "true")
+      val rocksdbChangelogCheckpointingConfKey =
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled"
+      sqlConf.setConfString(rocksdbChangelogCheckpointingConfKey, "true")
+      val conf = RocksDBConf(StateStoreConf(sqlConf))
+
+      withDB(
+        remoteDir.getAbsolutePath,
+        version = 0,
+        conf = conf,
+        hadoopConf = hadoopConf
+      ) { db =>
+        db.put("key0", "value0")
+        val checkpointId1 = commitAndGetCheckpointId(db)
+
+        db.load(1, checkpointId1)
+        db.put("key1", "value1")
+        val checkpointId2 = commitAndGetCheckpointId(db)
+
+        db.load(2, checkpointId2)
+        db.put("key2", "value2")
+
+        // Simulate what happens when a task is killed, the thread's interrupt flag is set.
+        // This replicates the scenario where TaskContext.markTaskFailed() is called and
+        // the task failure listener invokes RocksDBStateStore.abort() -> rollback().
+        Thread.currentThread().interrupt()
+
+        // rollback() should not throw an exception
+        db.rollback()
+
+        // Clear the interrupt flag for subsequent operations
+        Thread.interrupted()
+
+        // Reload the store and insert a new value
+        db.load(2, checkpointId2)
+        db.put("key3", "value3")
+
+        // Verify the store has the correct values
+        assert(db.iterator().map(toStr).toSet ===
+          Set(("key0", "value0"), ("key1", "value1"), ("key3", "value3")))
+      }
+    }
+  }
+
   def commitAndGetCheckpointId(db: RocksDB): Option[String] = {
     val (v, ci) = db.commit()
     ci.stateStoreCkptId
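A note on the pair of interrupt calls in the test above: Thread.currentThread().interrupt() only sets the thread's interrupt flag (nothing is thrown until some blocking call checks it), while the static Thread.interrupted() both reports and clears that flag, which is why the load/put/iterator calls after the rollback run on a clean thread. A self-contained sketch of those standard JVM semantics:

object InterruptFlagDemo {
  def main(args: Array[String]): Unit = {
    Thread.currentThread().interrupt()
    assert(Thread.currentThread().isInterrupted) // flag is set; nothing thrown yet
    assert(Thread.interrupted())                 // reports true and clears the flag
    assert(!Thread.currentThread().isInterrupted)
  }
}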