@@ -601,7 +601,8 @@ object SQLConf {
"The class used to manage state data in stateful streaming queries. This class must " +
"be a subclass of StateStoreProvider, and must have a zero-arg constructor.")
.stringConf
.createOptional
.createWithDefault(
"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")

val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
@@ -897,7 +898,7 @@ class SQLConf extends Serializable with Logging {

def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)

def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS)
def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)

def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)

@@ -20,6 +20,10 @@ package org.apache.spark.sql.execution.streaming
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS}

/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
@@ -78,7 +82,40 @@ case class OffsetSeqMetadata(
def json: String = Serialization.write(this)(OffsetSeqMetadata.format)
}

object OffsetSeqMetadata {
object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)

def apply(
batchWatermarkMs: Long,
batchTimestampMs: Long,
sessionConf: RuntimeConfig): OffsetSeqMetadata = {
val confs = relevantSQLConfs.map { conf => conf.key -> sessionConf.get(conf.key) }.toMap
OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs)
}

/** Set the SparkSession configuration with the values in the metadata */
def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: RuntimeConfig): Unit = {
OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey =>

metadata.conf.get(confKey) match {

case Some(valueInMetadata) =>
// Config value exists in the metadata, update the session config with this value
val optionalValueInSession = sessionConf.getOption(confKey)
if (optionalValueInSession.isDefined && optionalValueInSession.get != valueInMetadata) {
logWarning(s"Updating the value of conf '$confKey' in current session from " +
s"'${optionalValueInSession.get}' to '$valueInMetadata'.")
}
sessionConf.set(confKey, valueInMetadata)

case None =>
// For backward compatibility, if a config was not recorded in the offset log,
// then log it, and let the existing conf value in SparkSession prevail.
logWarning (s"Conf '$confKey' was not found in the offset log, using existing value")
}
}
}
}
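To make the new metadata shape concrete, here is a small sketch (illustrative values; OffsetSeqMetadata is an internal Spark class, used here only for illustration) of the conf map this change records in the offset log and of the JSON round trip the log relies on.

```scala
import org.apache.spark.sql.execution.streaming.OffsetSeqMetadata

object OffsetSeqMetadataRoundTrip {
  def main(args: Array[String]): Unit = {
    // With this change the provider class is tracked alongside the shuffle partition count.
    val metadata = new OffsetSeqMetadata(
      batchWatermarkMs = 0L,
      batchTimestampMs = 1L,
      conf = Map(
        "spark.sql.shuffle.partitions" -> "200",
        "spark.sql.streaming.stateStore.providerClass" ->
          "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"))

    // json and apply(json) are symmetric, so the metadata survives a checkpoint round trip.
    val json = metadata.json
    assert(OffsetSeqMetadata(json) == metadata)
    println(json)

    // On restart, StreamExecution applies the recovered values back onto the session:
    //   OffsetSeqMetadata.setSessionConf(recovered, sparkSession.conf)
  }
}
```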
@@ -125,9 +125,8 @@ class StreamExecution(
}

/** Metadata associated with the offset seq of a batch in the query. */
protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0,
conf = Map(SQLConf.SHUFFLE_PARTITIONS.key ->
sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString))
protected var offsetSeqMetadata = OffsetSeqMetadata(
batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf)

override val id: UUID = UUID.fromString(streamMetadata.id)

@@ -285,9 +284,8 @@ class StreamExecution(
val sparkSessionToRunBatches = sparkSession.cloneSession()
// Adaptive execution can change num shuffle partitions, disallow
sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0,
conf = Map(SQLConf.SHUFFLE_PARTITIONS.key ->
sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)))
offsetSeqMetadata = OffsetSeqMetadata(
batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionToRunBatches.conf)

if (state.compareAndSet(INITIALIZING, ACTIVE)) {
// Unblock `awaitInitialization`
@@ -441,21 +439,9 @@ class StreamExecution(

// update offset metadata
nextOffsets.metadata.foreach { metadata =>
val shufflePartitionsSparkSession: Int =
sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS)
val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, {
// For backward compatibility, if # partitions was not recorded in the offset log,
// then ensure it is not missing. The new value is picked up from the conf.
logWarning("Number of shuffle partitions from previous run not found in checkpoint. "
+ s"Using the value from the conf, $shufflePartitionsSparkSession partitions.")
shufflePartitionsSparkSession
})
OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf)
offsetSeqMetadata = OffsetSeqMetadata(
metadata.batchWatermarkMs, metadata.batchTimestampMs,
metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString))
// Update conf with correct number of shuffle partitions
sparkSessionToRunBatches.conf.set(
SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString)
metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
}

/* identify the current batch id: if commit log indicates we successfully processed the
@@ -172,8 +172,7 @@ object StateStoreProvider {
indexOrdinal: Option[Int], // for sorting the data
storeConf: StateStoreConf,
hadoopConf: Configuration): StateStoreProvider = {
val providerClass = storeConf.providerClass.map(Utils.classForName)
.getOrElse(classOf[HDFSBackedStateStoreProvider])
val providerClass = Utils.classForName(storeConf.providerClass)
val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider]
provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
provider
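The simplified lookup above always resolves the configured class reflectively. A self-contained sketch of that pattern, with placeholder types standing in for StateStoreProvider and its subclasses (nothing here is Spark API except the idea):

```scala
// Placeholder interface and implementation; in Spark these are StateStoreProvider and a subclass.
trait Provider {
  def init(id: String): Unit
}

class MyProvider extends Provider {
  override def init(id: String): Unit = println(s"initialized provider for $id")
}

object ReflectiveProviderLoad {
  def main(args: Array[String]): Unit = {
    // In Spark this name comes from StateStoreConf.providerClass (never empty after this change).
    val providerClassName = classOf[MyProvider].getName

    // Utils.classForName is roughly Class.forName plus Spark's classloader handling.
    val clazz = Class.forName(providerClassName)

    // newInstance() needs a public zero-arg constructor, hence the requirement in the conf doc.
    val provider = clazz.newInstance().asInstanceOf[Provider]
    provider.init("store-0")
  }
}
```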
@@ -38,7 +38,7 @@ class StateStoreConf(@transient private val sqlConf: SQLConf)
* Optional fully qualified name of the subclass of [[StateStoreProvider]]
* managing state data. That is, the implementation of the State Store to use.
*/
val providerClass: Option[String] = sqlConf.stateStoreProviderClass
val providerClass: String = sqlConf.stateStoreProviderClass

/**
* Additional configurations related to state store. This will capture all configs in
@@ -37,16 +37,18 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {
}

// None set
assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}"""))
assert(new OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}"""))

// One set
assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
assert(new OffsetSeqMetadata(1, 0, Map.empty) ===
OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
assert(new OffsetSeqMetadata(0, 2, Map.empty) ===
OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) ===
OffsetSeqMetadata(s"""{"conf": {"$key":2}}"""))

// Two set
assert(OffsetSeqMetadata(1, 2, Map.empty) ===
assert(new OffsetSeqMetadata(1, 2, Map.empty) ===
OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) ===
OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}"""))
@@ -637,19 +637,11 @@ class StreamSuite extends StreamTest {
}

testQuietly("specify custom state store provider") {
val queryName = "memStream"
val providerClassName = classOf[TestStateStoreProvider].getCanonicalName
withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) {
val input = MemoryStream[Int]
val query = input
.toDS()
.groupBy()
.count()
.writeStream
.outputMode("complete")
.format("memory")
.queryName(queryName)
.start()
val df = input.toDS().groupBy().count()
val query = df.writeStream.outputMode("complete").format("memory").queryName("name").start()
input.addData(1, 2, 3)
val e = intercept[Exception] {
query.awaitTermination()
@@ -659,6 +651,45 @@
assert(e.getMessage.contains("instantiated"))
}
}

testQuietly("custom state store provider read from offset log") {
val input = MemoryStream[Int]
val df = input.toDS().groupBy().count()
val providerConf1 = "spark.sql.streaming.stateStore.providerClass" ->
"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"
val providerConf2 = "spark.sql.streaming.stateStore.providerClass" ->
classOf[TestStateStoreProvider].getCanonicalName

def runQuery(queryName: String, checkpointLoc: String): Unit = {
val query = df.writeStream
.outputMode("complete")
.format("memory")
.queryName(queryName)
.option("checkpointLocation", checkpointLoc)
.start()
input.addData(1, 2, 3)
query.processAllAvailable()
query.stop()
}

withTempDir { dir =>
val checkpointLoc1 = new File(dir, "1").getCanonicalPath
withSQLConf(providerConf1) {
runQuery("query1", checkpointLoc1) // generate checkpoints
}

val checkpointLoc2 = new File(dir, "2").getCanonicalPath
withSQLConf(providerConf2) {
// Verify new query will use new provider that throws error on loading
intercept[Exception] {
runQuery("query2", checkpointLoc2)
}

// Verify old query from checkpoint will still use old provider
runQuery("query1", checkpointLoc1)
}
}
}
}
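As a usage-level summary of what these tests pin down, here is a hedged end-user sketch (not from this PR; the checkpoint path is illustrative and MemoryStream is an internal test utility) of starting a query whose checkpoint records the provider class, so later restarts keep that provider regardless of the session conf at restart time.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

object ProviderPrecedenceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("provider-precedence")
      .getOrCreate()
    import spark.implicits._
    implicit val sqlContext = spark.sqlContext

    // Only queries *started* under this setting record it in their offset log.
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")

    val input = MemoryStream[Int]
    val counts = input.toDS().groupBy().count()

    val query = counts.writeStream
      .outputMode("complete")
      .format("memory")
      .queryName("counts")
      .option("checkpointLocation", "/tmp/provider-precedence-checkpoint") // illustrative path
      .start()

    input.addData(1, 2, 3)
    query.processAllAvailable()
    query.stop()

    // A later restart from the same checkpoint reads the provider class back from the offset
    // log via OffsetSeqMetadata.setSessionConf, so changing the session conf in the meantime
    // does not switch providers for this query.
    spark.stop()
  }
}
```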

abstract class FakeSource extends StreamSourceProvider {