
[SPARK-51823][SS] Add config to not persist state store on executors #50612

Closed
SQLConf.scala
@@ -2519,6 +2519,18 @@ object SQLConf {
.stringConf
.createWithDefault(CompressionCodec.LZ4)

val STATE_STORE_UNLOAD_ON_COMMIT =
buildConf("spark.sql.streaming.stateStore.unloadOnCommit")
.internal()
.doc("When true, Spark will synchronously run maintenance and then close each StateStore " +
"instance on task completion. This removes the overhead of keeping every StateStore " +
"loaded indefinitely, at the cost of having to reload each StateStore every batch. " +
"Stateful applications that are failing due to resource exhaustion or that use " +
"dynamic allocation may benefit from enabling this.")
.version("4.1.0")
.booleanConf
.createWithDefault(false)

val CHECKPOINT_RENAMEDFILE_CHECK_ENABLED =
buildConf("spark.sql.streaming.checkpoint.renamedFileCheck.enabled")
.doc("When true, Spark will validate if renamed checkpoint file exists.")
@@ -6150,6 +6162,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY)

def stateStoreUnloadOnCommit: Boolean = getConf(STATE_STORE_UNLOAD_ON_COMMIT)

def streamingMaintenanceInterval: Long = getConf(STREAMING_MAINTENANCE_INTERVAL)

def stateStoreCompressionCodec: String = getConf(STATE_STORE_COMPRESSION_CODEC)
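
For context, enabling the new flag from application code is just a session-level conf set before starting a stateful query. Below is a minimal sketch; the rate source, column names, and query shape are illustrative only and not part of the patch:

import org.apache.spark.sql.SparkSession

object UnloadOnCommitExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("unload-on-commit-sketch")
      .getOrCreate()
    import spark.implicits._

    // Ask Spark to run maintenance and close each StateStore when the task
    // that committed it finishes, instead of keeping it cached on the executor.
    spark.conf.set("spark.sql.streaming.stateStore.unloadOnCommit", "true")

    // Any stateful query works; a streaming aggregation is the simplest case.
    val counts = spark.readStream
      .format("rate")
      .load()
      .groupBy($"value" % 10)
      .count()

    counts.writeStream
      .outputMode("complete")
      .format("console")
      .start()
      .awaitTermination()
  }
}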
StateStore.scala
@@ -1007,13 +1007,29 @@ object StateStore extends Logging {
log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, storeProviderId.queryRunId)}")
}

val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
providerIdsToUnload.foreach(unload(_))
// Only tell the state store coordinator we are active if we will remain active
// after the task. When we unload after committing, there's no need for the coordinator
// to track which executor has which provider
if (!storeConf.unloadOnCommit) {
val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
providerIdsToUnload.foreach(unload(_))
}

provider
}
}

/** Runs maintenance and then unloads a state store provider */
def doMaintenanceAndUnload(storeProviderId: StateStoreProviderId): Unit = {
loadedProviders.synchronized {
loadedProviders.remove(storeProviderId)
}.foreach { provider =>
provider.doMaintenance()
provider.close()
}
}

/** Unload a state store provider */
def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized {
loadedProviders.remove(storeProviderId).foreach(_.close())
@@ -1072,7 +1088,7 @@ object StateStore extends Logging {
val numMaintenanceThreads = storeConf.numStateStoreMaintenanceThreads
val maintenanceShutdownTimeout = storeConf.stateStoreMaintenanceShutdownTimeout
loadedProviders.synchronized {
if (SparkEnv.get != null && !isMaintenanceRunning) {
if (SparkEnv.get != null && !isMaintenanceRunning && !storeConf.unloadOnCommit) {
maintenanceTask = new MaintenanceTask(
storeConf.maintenanceInterval,
task = { doMaintenance() }
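
Two details are worth noting in the StateStore changes above: the background MaintenanceTask is simply not started when unloadOnCommit is set (the maintenance pass happens synchronously via doMaintenanceAndUnload instead), and doMaintenanceAndUnload removes the provider from loadedProviders while holding the monitor but runs doMaintenance() and close() after the synchronized block returns, so slow cleanup work does not block other tasks loading providers. A generic sketch of that second idiom, using hypothetical names (Registry, removeAndClose) purely for illustration:

import scala.collection.mutable

object Registry {
  private val entries = mutable.HashMap[String, AutoCloseable]()

  def register(key: String, resource: AutoCloseable): Unit =
    entries.synchronized { entries.update(key, resource) }

  // Same shape as doMaintenanceAndUnload: remove the entry under the lock,
  // then do the potentially slow cleanup on the returned Option outside it.
  def removeAndClose(key: String): Unit = {
    entries.synchronized {
      entries.remove(key)
    }.foreach { resource =>
      resource.close() // runs after the synchronized block has exited
    }
  }
}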
StateStoreConf.scala
@@ -103,6 +103,9 @@ class StateStoreConf(
val reportSnapshotUploadLag: Boolean =
sqlConf.stateStoreCoordinatorReportSnapshotUploadLag

/** Whether to unload the store on task completion. */
val unloadOnCommit = sqlConf.stateStoreUnloadOnCommit

/**
* Additional configurations related to state store. This will capture all configs in
* SQLConf that start with `spark.sql.streaming.stateStore.`
StateStoreRDD.scala
@@ -136,6 +136,13 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
stateSchemaBroadcast,
useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
useMultipleValuesPerKey)

if (storeConf.unloadOnCommit) {
ctxt.addTaskCompletionListener[Unit](_ => {
StateStore.doMaintenanceAndUnload(storeProviderId)
})
}

storeUpdateFunction(store, inputIter)
}
}
RocksDBStateStoreIntegrationSuite.scala
@@ -223,6 +223,49 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}

testWithColumnFamilies("SPARK-51823: unload state stores on commit",
TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
withTempDir { dir =>
withSQLConf(
(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName),
(SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath),
(SQLConf.SHUFFLE_PARTITIONS.key -> "1"),
(SQLConf.STATE_STORE_UNLOAD_ON_COMMIT.key -> "true")) {
// Make sure we start fresh, without any stale state store entries
Utils.clearLocalRootDirs()

val inputData = MemoryStream[Int]

val query = inputData.toDS().toDF("value")
.select($"value")
.groupBy($"value")
.agg(count("*"))
.writeStream
.format("console")
.outputMode("complete")
.start()
try {
inputData.addData(1, 2)
inputData.addData(2, 3)
query.processAllAvailable()

// StateStore should be unloaded, so its tmp dir shouldn't exist
var tmpFiles = new File(Utils.getLocalDir(sparkConf)).listFiles()
assert(tmpFiles.filter(_.getName().startsWith("StateStore")).isEmpty)

inputData.addData(3, 4)
inputData.addData(4, 5)
query.processAllAvailable()

tmpFiles = new File(Utils.getLocalDir(sparkConf)).listFiles()
assert(tmpFiles.filter(_.getName().startsWith("StateStore")).isEmpty)
} finally {
query.stop()
}
}
}
}

testWithChangelogCheckpointingEnabled(
"Streaming aggregation RocksDB State Store backward compatibility.") {
val checkpointDir = Utils.createTempDir().getCanonicalFile
StateStoreRDDSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.tags.ExtendedSQLTest
import org.apache.spark.util.{CompletionIterator, Utils}

@@ -227,6 +228,32 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter {
}
}

test("SPARK-51823: unload on commit") {
Contributor: Should we add an integration test under RocksDBStateStoreIntegrationSuite with the config enabled?

Contributor Author: Added a basic integration test, let me know if there's anything you want to add to it.

withSparkSession(
SparkSession.builder()
.config(sparkConf)
.config(SQLConf.STATE_STORE_UNLOAD_ON_COMMIT.key, true)
.getOrCreate()) { spark =>
val path = Utils.createDirectory(tempDir, Random.nextFloat().toString).toString
val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
.mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
keySchema, valueSchema,
NoPrefixKeyStateEncoderSpec(keySchema))(increment)

assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))

// Generate next version of stores
val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
.mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 1),
keySchema, valueSchema,
NoPrefixKeyStateEncoderSpec(keySchema))(increment)
assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))

// Make sure the previous RDD still has the same data.
assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
}
}

private def makeRDD(sc: SparkContext, seq: Seq[(String, Int)]): RDD[(String, Int)] = {
sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
}