diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index c2087ec219e57..7cb86dc143844 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -665,6 +665,37 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     checkInvalidVersion(3)
   }
 
+  test("two concurrent StateStores - one for read-only and one for read-write") {
+    // During Streaming Aggregation, we have two StateStores per task, one used as read-only in
+    // `StateStoreRestoreExec`, and one read-write used in `StateStoreSaveExec`. `StateStore.abort`
+    // will be called for these StateStores if they haven't committed their results. We need to
+    // make sure that `abort` in read-only store after a `commit` in the read-write store doesn't
+    // accidentally lead to the deletion of state.
+    val dir = newDir()
+    val storeId = StateStoreId(dir, 0L, 1)
+    val provider0 = newStoreProvider(storeId)
+    // prime state
+    val store = provider0.getStore(0)
+    val key = "a"
+    put(store, key, 1)
+    store.commit()
+    assert(rowsToSet(store.iterator()) === Set(key -> 1))
+
+    // two state stores
+    val provider1 = newStoreProvider(storeId)
+    val restoreStore = provider1.getStore(1)
+    val saveStore = provider1.getStore(1)
+
+    put(saveStore, key, get(restoreStore, key).get + 1)
+    saveStore.commit()
+    restoreStore.abort()
+
+    // check that state is correct for next batch
+    val provider2 = newStoreProvider(storeId)
+    val finalStore = provider2.getStore(2)
+    assert(rowsToSet(finalStore.iterator()) === Set(key -> 2))
+  }
+
   /** Return a new provider with a random id */
   def newStoreProvider(): ProviderClass
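
For readers skimming the patch, here is a minimal, self-contained Scala sketch of the commit/abort contract the new test pins down. This is a toy model, not Spark's `StateStore` implementation; `ToyProvider`, `Handle`, and the demo names are hypothetical stand-ins for `StateStoreProvider.getStore`, `commit`, and `abort` as used in the suite above. The point it illustrates: `abort` on one store handle discards only that handle's uncommitted delta and must never remove a version that another handle has already committed.

// Toy model (illustrative only, not Spark's implementation) of versioned
// state with commit/abort semantics.
import scala.collection.mutable

final class ToyProvider {
  // Committed snapshots by version; version 0 starts empty.
  private val committed = mutable.Map[Long, Map[String, Int]](0L -> Map.empty)

  final class Handle(baseVersion: Long) {
    private val delta = mutable.Map[String, Int]()

    def get(key: String): Option[Int] =
      delta.get(key).orElse(committed(baseVersion).get(key))

    def put(key: String, value: Int): Unit = delta(key) = value

    // Publishes baseVersion + 1 from the base snapshot plus this handle's delta.
    def commit(): Long = {
      committed(baseVersion + 1) = committed(baseVersion) ++ delta
      baseVersion + 1
    }

    // Drops only this handle's buffered writes. Crucially, it does NOT touch
    // committed(baseVersion + 1), which another handle may have just published.
    def abort(): Unit = delta.clear()
  }

  def getStore(version: Long): Handle = new Handle(version)
}

object ToyProviderDemo extends App {
  val p = new ToyProvider

  // Prime version 1 with a -> 1, mirroring the test's first provider.
  val s0 = p.getStore(0)
  s0.put("a", 1)
  s0.commit()

  // Two handles on the same base version, as in streaming aggregation.
  val restore = p.getStore(1) // read-only role (StateStoreRestoreExec)
  val save = p.getStore(1)    // read-write role (StateStoreSaveExec)
  save.put("a", restore.get("a").get + 1)
  save.commit()   // publishes version 2 with a -> 2
  restore.abort() // must leave the just-committed version 2 intact

  assert(p.getStore(2).get("a").contains(2))
}

Under these assumptions, an `abort` that deleted shared on-disk state for the target version (rather than just the handle's private buffer) would wipe out the save store's commit, which is exactly the regression the test guards against.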