diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
index 637d11ad890bf..2de429ee10763 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
@@ -39,6 +39,7 @@ import org.apache.spark.internal.LogKeys.{CHECKSUM, NUM_BYTES, PATH, TIMEOUT}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager.CancellableFSDataOutputStream
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.Utils
 
 /** Information about the creator of the checksum file. Useful for debugging */
 case class ChecksumFileCreatorInfo(
@@ -500,16 +501,14 @@ class ChecksumCancellableFSDataOutputStream(
   @volatile private var closed = false
 
   override def cancel(): Unit = {
-    val mainFuture = Future {
+    // Cancel both streams synchronously rather than using futures. If the current thread is
+    // interrupted and we call this method, scheduling work on futures would immediately throw
+    // InterruptedException leaving the streams in an inconsistent state.
+    Utils.tryWithSafeFinally {
       mainStream.cancel()
-    }(uploadThreadPool)
-
-    val checksumFuture = Future {
+    } {
       checksumStream.cancel()
-    }(uploadThreadPool)
-
-    awaitResult(mainFuture, Duration.Inf)
-    awaitResult(checksumFuture, Duration.Inf)
+    }
   }
 
   override def close(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 8a2ed6d9a529d..c92c5017cada9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -1650,17 +1650,23 @@ class RocksDB(
    * Drop uncommitted changes, and roll back to previous version.
    */
   def rollback(): Unit = {
-    numKeysOnWritingVersion = numKeysOnLoadedVersion
-    numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
-    loadedVersion = -1L
-    lastCommitBasedStateStoreCkptId = None
-    lastCommittedStateStoreCkptId = None
-    loadedStateStoreCkptId = None
-    sessionStateStoreCkptId = None
-    lineageManager.clear()
-    changelogWriter.foreach(_.abort())
-    // Make sure changelogWriter gets recreated next time.
-    changelogWriter = None
+    logInfo(
+      log"Rolling back uncommitted changes on version ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+    try {
+      numKeysOnWritingVersion = numKeysOnLoadedVersion
+      numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
+      loadedVersion = -1L
+      lastCommitBasedStateStoreCkptId = None
+      lastCommittedStateStoreCkptId = None
+      loadedStateStoreCkptId = None
+      sessionStateStoreCkptId = None
+      lineageManager.clear()
+      changelogWriter.foreach(_.abort())
+    } finally {
+      // Make sure changelogWriter gets recreated next time even if the changelogWriter aborts with
+      // an exception.
+      changelogWriter = None
+    }
     logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
index 0b9690ee72775..c48b492e27c27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
@@ -70,6 +70,10 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
 
   implicit def toArray(str: String): Array[Byte] = if (str != null) str.getBytes else null
 
+  implicit def toStr(bytes: Array[Byte]): String = if (bytes != null) new String(bytes) else null
+
+  def toStr(kv: ByteArrayPair): (String, String) = (toStr(kv.key), toStr(kv.value))
+
   case class FailureConf(ifEnableStateStoreCheckpointIds: Boolean, fileType: String) {
     override def toString: String = {
       s"ifEnableStateStoreCheckpointIds = $ifEnableStateStoreCheckpointIds, " +
@@ -824,6 +828,62 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
     }
   }
 
+  /**
+   * Test that verifies that when a task is interrupted, the store's rollback() method does not
+   * throw an exception and the store can still be used after the rollback.
+   */
+  test("SPARK-54585: Interrupted task calling rollback does not throw an exception") {
+    val hadoopConf = new Configuration()
+    hadoopConf.set(
+      STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key,
+      fileManagerClassName
+    )
+    withTempDirAllowFailureInjection { (remoteDir, _) =>
+      val sqlConf = new SQLConf()
+      sqlConf.setConfString("spark.sql.streaming.checkpoint.fileChecksum.enabled", "true")
+      val rocksdbChangelogCheckpointingConfKey =
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled"
+      sqlConf.setConfString(rocksdbChangelogCheckpointingConfKey, "true")
+      val conf = RocksDBConf(StateStoreConf(sqlConf))
+
+      withDB(
+        remoteDir.getAbsolutePath,
+        version = 0,
+        conf = conf,
+        hadoopConf = hadoopConf
+      ) { db =>
+        db.put("key0", "value0")
+        val checkpointId1 = commitAndGetCheckpointId(db)
+
+        db.load(1, checkpointId1)
+        db.put("key1", "value1")
+        val checkpointId2 = commitAndGetCheckpointId(db)
+
+        db.load(2, checkpointId2)
+        db.put("key2", "value2")
+
+        // Simulate what happens when a task is killed, the thread's interrupt flag is set.
+        // This replicates the scenario where TaskContext.markTaskFailed() is called and
+        // the task failure listener invokes RocksDBStateStore.abort() -> rollback().
+        Thread.currentThread().interrupt()
+
+        // rollback() should not throw an exception
+        db.rollback()
+
+        // Clear the interrupt flag for subsequent operations
+        Thread.interrupted()
+
+        // Reload the store and insert a new value
+        db.load(2, checkpointId2)
+        db.put("key3", "value3")
+
+        // Verify the store has the correct values
+        assert(db.iterator().map(toStr).toSet ===
+          Set(("key0", "value0"), ("key1", "value1"), ("key3", "value3")))
+      }
+    }
+  }
+
   def commitAndGetCheckpointId(db: RocksDB): Option[String] = {
     val (v, ci) = db.commit()
     ci.stateStoreCkptId