@@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionRead
 import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH
 import org.apache.spark.sql.execution.streaming.CheckpointFileManager
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1}
+import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, StateStoreMetadataV1}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -46,6 +46,7 @@ case class StateMetadataTableEntry(
     numPartitions: Int,
     minBatchId: Long,
     maxBatchId: Long,
+    operatorPropertiesJson: String,
     numColsPrefixKey: Int) {
   def toRow(): InternalRow = {
     new GenericInternalRow(
@@ -55,6 +56,7 @@
         numPartitions,
         minBatchId,
         maxBatchId,
+        UTF8String.fromString(operatorPropertiesJson),
         numColsPrefixKey))
   }
 }
@@ -68,6 +70,7 @@ object StateMetadataTableEntry {
       .add("numPartitions", IntegerType)
       .add("minBatchId", LongType)
       .add("maxBatchId", LongType)
+      .add("operatorProperties", StringType)
   }
 }
 
@@ -188,29 +191,59 @@ class StateMetadataPartitionReader(
     } else Array.empty
   }
 
-  private def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
+  // Need this to be accessible from IncrementalExecution for the planning rule.
+  private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
     val stateDir = new Path(checkpointLocation, "state")
     val opIds = fileManager
       .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted
     opIds.map { opId =>
-      new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read()
+      val operatorIdPath = new Path(stateDir, opId.toString)
+      // Check whether the OperatorStateMetadataV2 path exists; if it does, read it.
+      // Otherwise, fall back to OperatorStateMetadataV1.
+      val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataDirPath(operatorIdPath)
+      val operatorStateMetadataVersion = if (fileManager.exists(operatorStateMetadataV2Path)) {
+        2
+      } else {
+        1
+      }
+      OperatorStateMetadataReader.createReader(
+        operatorIdPath, hadoopConf, operatorStateMetadataVersion).read() match {
+        case Some(metadata) => metadata
+        case None => OperatorStateMetadataV1(OperatorInfoV1(opId, null),
+          Array(StateStoreMetadataV1(null, -1, -1)))
+      }
     }
   }
 
   private[sql] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
     allOperatorStateMetadata.flatMap { operatorStateMetadata =>
-      require(operatorStateMetadata.version == 1)
-      val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1]
-      operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata =>
-        StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId,
-          operatorStateMetadataV1.operatorInfo.operatorName,
-          stateStoreMetadata.storeName,
-          stateStoreMetadata.numPartitions,
-          if (batchIds.nonEmpty) batchIds.head else -1,
-          if (batchIds.nonEmpty) batchIds.last else -1,
-          stateStoreMetadata.numColsPrefixKey
-        )
-      }
+      require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2)
+      operatorStateMetadata match {
+        case v1: OperatorStateMetadataV1 =>
+          v1.stateStoreInfo.map { stateStoreMetadata =>
+            StateMetadataTableEntry(v1.operatorInfo.operatorId,
+              v1.operatorInfo.operatorName,
+              stateStoreMetadata.storeName,
+              stateStoreMetadata.numPartitions,
+              if (batchIds.nonEmpty) batchIds.head else -1,
+              if (batchIds.nonEmpty) batchIds.last else -1,
+              null,
+              stateStoreMetadata.numColsPrefixKey
+            )
+          }
+        case v2: OperatorStateMetadataV2 =>
+          v2.stateStoreInfo.map { stateStoreMetadata =>
+            StateMetadataTableEntry(v2.operatorInfo.operatorId,
+              v2.operatorInfo.operatorName,
+              stateStoreMetadata.storeName,
+              stateStoreMetadata.numPartitions,
+              if (batchIds.nonEmpty) batchIds.head else -1,
+              if (batchIds.nonEmpty) batchIds.last else -1,
+              v2.operatorPropertiesJson,
+              -1 // numColsPrefixKey is not available in OperatorStateMetadataV2
+            )
+          }
+      }
    }.iterator
  }
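
End to end, the new column surfaces through the state metadata source. A minimal sketch of querying it (not part of this diff; the checkpoint path is a placeholder):

// With this change, the state metadata table carries an `operatorProperties`
// JSON string column; it is null for operators that still write
// OperatorStateMetadataV1, and holds the operator's JSON properties for V2.
val metadataDf = spark.read
  .format("state-metadata")
  .load("/tmp/checkpoints/my-query") // placeholder checkpoint location

metadataDf
  .select("operatorId", "operatorName", "stateStoreName", "operatorProperties")
  .show(truncate = false)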
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataWriter
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -208,13 +208,16 @@ class IncrementalExecution(
         }
         val schemaValidationResult = statefulOp.
           validateAndMaybeEvolveStateSchema(hadoopConf, currentBatchId, stateSchemaVersion)
+        val stateSchemaPaths = schemaValidationResult.map(_.schemaPath)
+        // write out the state schema paths to the metadata file
         statefulOp match {
-          case stateStoreWriter: StateStoreWriter =>
-            val metadata = stateStoreWriter.operatorStateMetadata()
-            // TODO: [SPARK-48849] Populate metadata with stateSchemaPaths if metadata version is v2
-            val metadataWriter = new OperatorStateMetadataWriter(new Path(
-              checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf)
+          case ssw: StateStoreWriter =>
+            val metadata = ssw.operatorStateMetadata(stateSchemaPaths)
+            val metadataWriter = OperatorStateMetadataWriter.createWriter(
+              new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
+              hadoopConf,
+              ssw.operatorStateMetadataVersion,
+              Some(currentBatchId))
             metadataWriter.write(metadata)
           case _ =>
         }
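
The factory call above replaces direct construction of the metadata writer. A plausible shape of that dispatch, shown only for orientation (a hypothetical sketch — the concrete writer class names and signatures are assumptions, not this PR's code):

object OperatorStateMetadataWriter {
  // Hypothetical: pick a writer implementation from the operator's declared
  // metadata version; the batch id is assumed to matter only for v2, which
  // versions its metadata files per batch.
  def createWriter(
      stateCheckpointPath: Path,
      hadoopConf: Configuration,
      version: Int,
      currentBatchId: Option[Long] = None): OperatorStateMetadataWriter = version match {
    case 1 => new OperatorStateMetadataV1Writer(stateCheckpointPath, hadoopConf)
    case 2 => new OperatorStateMetadataV2Writer(stateCheckpointPath, hadoopConf, currentBatchId.get)
    case v => throw new IllegalArgumentException(s"Unsupported operator metadata version: $v")
  }
}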
@@ -456,8 +459,12 @@
       val reader = new StateMetadataPartitionReader(
         new Path(checkpointLocation).getParent.toString,
         new SerializableConfiguration(hadoopConf))
-      ret = reader.stateMetadata.map { metadataTableEntry =>
-        metadataTableEntry.operatorId -> metadataTableEntry.operatorName
+      val opMetadataList = reader.allOperatorStateMetadata
+      ret = opMetadataList.map {
+        case OperatorStateMetadataV1(operatorInfo, _) =>
+          operatorInfo.operatorId -> operatorInfo.operatorName
+        case OperatorStateMetadataV2(operatorInfo, _, _) =>
+          operatorInfo.operatorId -> operatorInfo.operatorName
       }.toMap
     } catch {
       case e: Exception =>
@@ -227,10 +227,12 @@ case class StreamingSymmetricHashJoinExec(
   private val stateStoreNames =
     SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
 
-  override def operatorStateMetadata(): OperatorStateMetadata = {
+  override def operatorStateMetadata(
+      stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = {
     val info = getStateInfo
     val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
-    val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
+    val stateStoreInfo =
+      stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
     OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
   }
 
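The defaulted parameter keeps V1 call sites source-compatible: existing `operatorStateMetadata()` callers compile unchanged, while V2 operators pass schema file paths explicitly. A standalone illustration of the pattern (names invented for the sketch):

trait MetadataProvider {
  // Adding a defaulted parameter preserves the old no-arg call shape.
  def metadata(stateSchemaPaths: List[String] = List.empty): String
}

object V1Provider extends MetadataProvider {
  // Overriding methods inherit the default value from the trait in Scala.
  override def metadata(stateSchemaPaths: List[String]): String = "v1" // paths ignored
}

object V2Provider extends MetadataProvider {
  override def metadata(stateSchemaPaths: List[String]): String =
    s"v2 with ${stateSchemaPaths.size} schema file(s)"
}

V1Provider.metadata()                     // legacy call site still compiles
V2Provider.metadata(List("schema/0_abc")) // new call site passes paths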
@@ -21,6 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+import org.json4s.JString
+import org.json4s.jackson.JsonMethods.{compact, render}
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -96,6 +100,8 @@ case class TransformWithStateExec(
     }
   }
 
+  override def operatorStateMetadataVersion: Int = 2
+
   /**
    * We initialize this processor handle in the driver to run the init function
    * and fetch the schemas of the state variables initialized in this processor.
@@ -382,12 +388,47 @@
       batchId: Long,
       stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
     assert(stateSchemaVersion >= 3)
-    val newColumnFamilySchemas = getColFamilySchemas()
+    val newSchemas = getColFamilySchemas()
     val stateSchemaDir = stateSchemaDirPath()
-    val stateSchemaFilePath = new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
-    List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
-      newColumnFamilySchemas.values.toList, session.sessionState, stateSchemaVersion,
-      schemaFilePath = Some(stateSchemaFilePath)))
+    val newStateSchemaFilePath =
+      new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
+    val metadataPath = new Path(getStateInfo.checkpointLocation, s"${getStateInfo.operatorId}")
+    val metadataReader = OperatorStateMetadataReader.createReader(
+      metadataPath, hadoopConf, operatorStateMetadataVersion)
+    val operatorStateMetadata = metadataReader.read()
+    val oldStateSchemaFilePath: Option[Path] = operatorStateMetadata match {
+      case Some(metadata) =>
+        metadata match {
+          case v2: OperatorStateMetadataV2 =>
+            Some(new Path(v2.stateStoreInfo.head.stateSchemaFilePath))
+          case _ => None
+        }
+      case None => None
+    }
+    List(StateSchemaCompatibilityChecker.
+      validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
+        newSchemas.values.toList, session.sessionState, stateSchemaVersion,
+        storeName = StateStoreId.DEFAULT_STORE_NAME,
+        oldSchemaFilePath = oldStateSchemaFilePath,
+        newSchemaFilePath = Some(newStateSchemaFilePath)))
   }
 
+  /** Metadata of this stateful operator and its state stores. */
+  override def operatorStateMetadata(
+      stateSchemaPaths: List[String]): OperatorStateMetadata = {
+    val info = getStateInfo
+    val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
+    // stateSchemaFilePath should be populated at this point
+    val stateStoreInfo =
+      Array(StateStoreMetadataV2(
+        StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head))
+
+    val operatorPropertiesJson: JValue =
+      ("timeMode" -> JString(timeMode.toString)) ~
+      ("outputMode" -> JString(outputMode.toString))
+
+    val json = compact(render(operatorPropertiesJson))
+    OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
+  }
 
   private def stateSchemaDirPath(): Path = {
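
To make operatorPropertiesJson concrete: json4s's `~` combines key/value pairs into a JObject, and `compact(render(...))` serializes it to a single-line string. A standalone sketch (the field values are examples):

import org.json4s.JString
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

// Mirrors the construction in operatorStateMetadata() above.
val props =
  ("timeMode" -> JString("ProcessingTime")) ~
  ("outputMode" -> JString("Update"))

val json = compact(render(props))
// json == {"timeMode":"ProcessingTime","outputMode":"Update"}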