[SPARK-50127] Implement Avro encoding for MapState and PrefixKeyScanStateEncoder #22

Merged · 4 commits · Nov 7, 2024
==== ListStateImplWithTTL.scala ====
@@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.{ListState, TTLConfig}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator
@@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
  * @param ttlConfig - TTL configuration for values stored in this state
  * @param batchTimestampMs - current batch processing timestamp.
  * @param metrics - metrics to be updated as part of stateful processing
+ * @param avroEnc - optional Avro serializer and deserializer for this state variable that
+ *                  is used by the StateStore to encode state in Avro format
+ * @param ttlAvroEnc - optional Avro serializer and deserializer for TTL state that
+ *                     is used by the StateStore to encode state in Avro format
  * @tparam S - data type of object that will be stored
  */
 class ListStateImplWithTTL[S](
@@ -45,8 +49,10 @@ class ListStateImplWithTTL[S](
     valEncoder: ExpressionEncoder[Any],
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
-    metrics: Map[String, SQLMetric] = Map.empty)
-  extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs)
+    metrics: Map[String, SQLMetric] = Map.empty,
+    avroEnc: Option[AvroEncoderSpec] = None,
+    ttlAvroEnc: Option[AvroEncoderSpec] = None)
+  extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs, ttlAvroEnc)
   with ListStateMetricsImpl
   with ListState[S] {
 
@@ -65,7 +71,8 @@ class ListStateImplWithTTL[S](
   private def initialize(): Unit = {
     store.createColFamilyIfAbsent(stateName, keyExprEnc.schema,
       getValueSchemaWithTTL(valEncoder.schema, true),
-      NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = true)
+      NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = true,
+      avroEncoderSpec = avroEnc)
   }
 
   /** Whether state exists or not. */
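Every state variable touched by this PR follows the wiring seen above: the Avro serde arrives as an `Option[AvroEncoderSpec]` defaulting to `None` and is forwarded to `createColFamilyIfAbsent`, so call sites that pass nothing keep the existing UnsafeRow-based encoding. A minimal, self-contained sketch of that pattern (the `AvroEncoderSpec` and `StateStore` below are simplified stand-ins, not the real Spark classes):

```scala
object OptionalSerdeSketch {
  // Stand-in for the real AvroEncoderSpec, which carries serializers/deserializers.
  case class AvroEncoderSpec(description: String)

  class StateStore {
    // Mirrors the shape of createColFamilyIfAbsent: the serde is optional,
    // and absence means "keep the UnsafeRow encoding".
    def createColFamilyIfAbsent(
        name: String,
        avroEncoderSpec: Option[AvroEncoderSpec] = None): Unit = {
      val encoding = avroEncoderSpec.map(_ => "avro").getOrElse("unsaferow")
      println(s"column family '$name' registered with $encoding encoding")
    }
  }

  def main(args: Array[String]): Unit = {
    val store = new StateStore
    store.createColFamilyIfAbsent("listState")                                // legacy path
    store.createColFamilyIfAbsent("listState", Some(AvroEncoderSpec("serde"))) // Avro path
  }
}
```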
==== MapStateImpl.scala ====
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
 import org.apache.spark.sql.streaming.MapState
 import org.apache.spark.sql.types.StructType
 
@@ -32,6 +32,8 @@ import org.apache.spark.sql.types.StructType
  * @param keyExprEnc - Spark SQL encoder for key
  * @param valEncoder - Spark SQL encoder for value
  * @param metrics - metrics to be updated as part of stateful processing
+ * @param avroEnc - optional Avro serializer and deserializer for this state variable that
+ *                  is used by the StateStore to encode state in Avro format
  * @tparam K - type of key for map state variable
  * @tparam V - type of value for map state variable
  */
@@ -41,7 +43,8 @@ class MapStateImpl[K, V](
     keyExprEnc: ExpressionEncoder[Any],
     userKeyEnc: ExpressionEncoder[Any],
     valEncoder: ExpressionEncoder[Any],
-    metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging {
+    metrics: Map[String, SQLMetric] = Map.empty,
+    avroEnc: Option[AvroEncoderSpec] = None) extends MapState[K, V] with Logging {
 
   // Pack grouping key and user key together as a prefixed composite key
   private val schemaForCompositeKeyRow: StructType = {
@@ -52,7 +55,7 @@ class MapStateImpl[K, V](
     keyExprEnc, userKeyEnc, valEncoder, stateName)
 
   store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow,
-    PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1))
+    PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1), avroEncoderSpec = avroEnc)
 
   /** Whether state exists or not. */
   override def exists(): Boolean = {
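`MapStateImpl` packs the grouping key and the user key into one composite row and registers it with `PrefixKeyScanStateEncoderSpec(schema, 1)`, i.e. the first column is the scan prefix; the `getAvroSerde` changes later in this PR split the same schema with `take(1)`/`drop(1)`. A sketch of that layout, assuming only the spark-sql `StructType` API on the classpath (field names here are illustrative):

```scala
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object CompositeKeySketch {
  def main(args: Array[String]): Unit = {
    val groupingKeySchema = StructType(Seq(StructField("key", StringType)))
    val userKeySchema = StructType(Seq(StructField("userKey", LongType)))

    // Grouping key first, user key second: listing all map entries for one
    // grouping key becomes a prefix scan over the first column.
    val compositeSchema = StructType(groupingKeySchema ++ userKeySchema)

    val prefix = StructType(compositeSchema.take(1)) // serialized by the key serde
    val suffix = StructType(compositeSchema.drop(1)) // serialized by the suffix-key serde
    println(s"prefix: ${prefix.simpleString}, suffix: ${suffix.simpleString}")
  }
}
```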
==== MapStateImplWithTTL.scala ====
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.{MapState, TTLConfig}
 import org.apache.spark.util.NextIterator
 
@@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
  * @param ttlConfig - the ttl configuration (time to live duration etc.)
  * @param batchTimestampMs - current batch processing timestamp.
  * @param metrics - metrics to be updated as part of stateful processing
+ * @param avroEnc - optional Avro serializer and deserializer for this state variable that
+ *                  is used by the StateStore to encode state in Avro format
+ * @param ttlAvroEnc - optional Avro serializer and deserializer for TTL state that
+ *                     is used by the StateStore to encode state in Avro format
  * @tparam K - type of key for map state variable
  * @tparam V - type of value for map state variable
  * @return - instance of MapState of type [K,V] that can be used to store state persistently
@@ -48,9 +52,11 @@ class MapStateImplWithTTL[K, V](
     valEncoder: ExpressionEncoder[Any],
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
-    metrics: Map[String, SQLMetric] = Map.empty)
+    metrics: Map[String, SQLMetric] = Map.empty,
+    avroEnc: Option[AvroEncoderSpec] = None,
+    ttlAvroEnc: Option[AvroEncoderSpec] = None)
   extends CompositeKeyTTLStateImpl[K](stateName, store,
-    keyExprEnc, userKeyEnc, batchTimestampMs)
+    keyExprEnc, userKeyEnc, batchTimestampMs, ttlAvroEnc)
   with MapState[K, V] with Logging {
 
   private val stateTypesEncoder = new CompositeKeyStateEncoder(
@@ -66,7 +72,8 @@ class MapStateImplWithTTL[K, V](
     getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema)
     store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow,
       getValueSchemaWithTTL(valEncoder.schema, true),
-      PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1))
+      PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1),
+      avroEncoderSpec = avroEnc)
   }
 
   /** Whether state exists or not. */
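The TTL variants extend `CompositeKeyTTLStateImpl`, which maintains a secondary index keyed by expiration timestamp so expired entries can be found with a range scan instead of a full scan; that is what `RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))` in the next file encodes. A plain-Scala sketch of the idea, using a hypothetical `(expirationMs, elementKey)` key layout:

```scala
import scala.collection.immutable.TreeMap

object TtlIndexSketch {
  def main(args: Array[String]): Unit = {
    // Keys sorted by (timestamp, element key): the timestamp is ordinal 0,
    // matching RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0)).
    var ttlIndex = TreeMap.empty[(Long, String), Unit]
    ttlIndex += ((1000L, "a") -> ())
    ttlIndex += ((2000L, "b") -> ())
    ttlIndex += ((3000L, "c") -> ())

    val batchTimestampMs = 2500L
    // "Everything expired by now" is a cheap range scan over the sorted keys.
    val expired = ttlIndex.keysIterator.takeWhile(_._1 <= batchTimestampMs).toSeq
    println(expired.map(_._2)) // List(a, b)
  }
}
```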
==== StateStoreColumnFamilySchemaUtils.scala ====
@@ -20,14 +20,45 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
-import org.apache.spark.sql.execution.streaming.state.AvroEncoderSpec
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema}
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StructField, StructType}
 
 object StateStoreColumnFamilySchemaUtils {
 
   def apply(initializeAvroSerde: Boolean): StateStoreColumnFamilySchemaUtils =
     new StateStoreColumnFamilySchemaUtils(initializeAvroSerde)
 
+
+  /**
+   * Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. For range scans
+   * we want to use big-endian encoding, so we need to convert the source schema to replace these
+   * types with BinaryType.
+   *
+   * @param schema The schema to convert
+   * @param ordinals If non-empty, only convert fields at these ordinals.
+   *                 If empty, convert all fields.
+   */
+  def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): StructType = {
+    val ordinalSet = ordinals.toSet
+    StructType(schema.fields.zipWithIndex.map { case (field, idx) =>
+      if ((ordinals.isEmpty || ordinalSet.contains(idx)) && isFixedSize(field.dataType)) {
+        // Convert numeric types to BinaryType while preserving nullability
+        field.copy(dataType = BinaryType)
+      } else {
+        field
+      }
+    })
+  }
+
+  private def isFixedSize(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType |
+         _: FloatType | _: DoubleType => true
+    case _ => false
+  }
+
+  def getTtlColFamilyName(stateName: String): String = {
+    "$ttl_" + stateName
+  }
 }
 
 /**
@@ -43,7 +74,10 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
    * for a particular key and value schema.
    */
   private def getAvroSerde(
-      keySchema: StructType, valSchema: StructType): Option[AvroEncoderSpec] = {
+      keySchema: StructType,
+      valSchema: StructType,
+      suffixKeySchema: Option[StructType] = None
+  ): Option[AvroEncoderSpec] = {
     if (initializeAvroSerde) {
       val avroType = SchemaConverters.toAvroType(valSchema)
       val avroOptions = AvroOptions(Map.empty)
@@ -56,7 +90,18 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
       val valueDeserializer = new AvroDeserializer(avroType, valSchema,
         avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
         avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
-      Some(AvroEncoderSpec(keySer, keyDe, valueSerializer, valueDeserializer))
+      val (suffixKeySer, suffixKeyDe) = if (suffixKeySchema.isDefined) {
+        val userKeyAvroType = SchemaConverters.toAvroType(suffixKeySchema.get)
+        val skSer = new AvroSerializer(suffixKeySchema.get, userKeyAvroType, nullable = false)
+        val skDe = new AvroDeserializer(userKeyAvroType, suffixKeySchema.get,
+          avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
+          avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
+        (Some(skSer), Some(skDe))
+      } else {
+        (None, None)
+      }
+      Some(AvroEncoderSpec(
+        keySer, keyDe, valueSerializer, valueDeserializer, suffixKeySer, suffixKeyDe))
     } else {
       None
     }
@@ -97,12 +142,60 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
       valEncoder: Encoder[V],
       hasTtl: Boolean): StateStoreColFamilySchema = {
     val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema)
+    val valSchema = getValueSchemaWithTTL(valEncoder.schema, hasTtl)
     StateStoreColFamilySchema(
       stateName,
       compositeKeySchema,
       getValueSchemaWithTTL(valEncoder.schema, hasTtl),
       Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
-      Some(userKeyEnc.schema))
+      Some(userKeyEnc.schema),
+      avroEnc = getAvroSerde(
+        StructType(compositeKeySchema.take(1)),
+        valSchema,
+        Some(StructType(compositeKeySchema.drop(1)))
+      )
+    )
   }
 
+  def getTtlStateSchema(
+      stateName: String,
+      keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = {
+    val ttlKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+      getSingleKeyTTLRowSchema(keyEncoder.schema), Seq(0))
+    val ttlValSchema = StructType(
+      Array(StructField("__dummy__", NullType)))
+    StateStoreColFamilySchema(
+      stateName,
+      ttlKeySchema,
+      ttlValSchema,
+      Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))),
+      avroEnc = getAvroSerde(
+        StructType(ttlKeySchema.take(1)),
+        ttlValSchema,
+        Some(StructType(ttlKeySchema.drop(1)))
+      )
+    )
+  }
+
+  def getTtlStateSchema(
+      stateName: String,
+      keyEncoder: ExpressionEncoder[Any],
+      userKeySchema: StructType): StateStoreColFamilySchema = {
+    val ttlKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+      getCompositeKeyTTLRowSchema(keyEncoder.schema, userKeySchema), Seq(0))
+    val ttlValSchema = StructType(
+      Array(StructField("__dummy__", NullType)))
+    StateStoreColFamilySchema(
+      stateName,
+      ttlKeySchema,
+      ttlValSchema,
+      Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))),
+      avroEnc = getAvroSerde(
+        StructType(ttlKeySchema.take(1)),
+        ttlValSchema,
+        Some(StructType(ttlKeySchema.drop(1)))
+      )
+    )
+  }
+
   def getTimerStateSchema(
@@ -113,6 +206,29 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
       stateName,
       keySchema,
       valSchema,
-      Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)))
+      Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)),
+      avroEnc = getAvroSerde(
+        StructType(keySchema.take(1)),
+        valSchema,
+        Some(StructType(keySchema.drop(1)))
+      ))
   }
+
+  def getTimerStateSchemaForSecIndex(
+      stateName: String,
+      keySchema: StructType,
+      valSchema: StructType): StateStoreColFamilySchema = {
+    val avroKeySchema = StateStoreColumnFamilySchemaUtils.
+      convertForRangeScan(keySchema, Seq(0))
+    StateStoreColFamilySchema(
+      stateName,
+      keySchema,
+      valSchema,
+      Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0))),
+      avroEnc = getAvroSerde(
+        StructType(avroKeySchema.take(1)),
+        valSchema,
+        Some(StructType(avroKeySchema.drop(1)))
+      ))
+  }
 }
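The `convertForRangeScan` comment above is the crux of this file: Avro writes ints and longs as zig-zag varints, which do not sort correctly under byte-wise comparison, so range-scan ordinals are re-declared as `BinaryType` and written in a byte-comparable form. The sketch below illustrates why a big-endian encoding with the sign bit flipped sorts correctly under unsigned lexicographic comparison; it demonstrates the property, not necessarily the exact byte scheme the state store uses:

```scala
import java.nio.ByteBuffer

object BigEndianSortSketch {
  // Big-endian bytes with the sign bit flipped: a standard trick so that
  // unsigned byte-wise ordering matches signed numeric ordering.
  def toSortableBytes(v: Long): Array[Byte] =
    ByteBuffer.allocate(java.lang.Long.BYTES).putLong(v ^ Long.MinValue).array()

  // Unsigned lexicographic comparison of byte arrays.
  def unsignedCompare(a: Array[Byte], b: Array[Byte]): Int =
    a.zip(b).collectFirst {
      case (x, y) if x != y => (x & 0xff) - (y & 0xff)
    }.getOrElse(a.length - b.length)

  def main(args: Array[String]): Unit = {
    val values = Seq(42L, -5L, 0L, Long.MinValue, Long.MaxValue, -1L)
    val byByteOrder = values.sortWith((a, b) =>
      unsignedCompare(toSortableBytes(a), toSortableBytes(b)) < 0)
    assert(byByteOrder == values.sorted) // byte order matches numeric order
    println(byByteOrder)
  }
}
```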
==== StatefulProcessorHandleImpl.scala ====
@@ -27,6 +27,7 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE_INIT
+import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, TimeMode, TTLConfig, ValueState}
 import org.apache.spark.util.Utils
@@ -140,7 +141,13 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
-  private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder)
+  private lazy val timerStateName = TimerStateUtils.getTimerStateVarName(
+    timeMode.toString)
+  private lazy val timerSecIndexColFamily = TimerStateUtils.getSecIndexColFamilyName(
+    timeMode.toString)
+  private lazy val timerState = new TimerStateImpl(
+    store, timeMode, keyEncoder, schemas(timerStateName).avroEnc,
+    schemas(timerSecIndexColFamily).avroEnc)
 
   /**
    * Function to register a timer for the given expiryTimestampMs
@@ -323,7 +330,7 @@ class StatefulProcessorHandleImpl(
       mapStateWithTTL
     } else {
       val mapStateWithoutTTL = new MapStateImpl[K, V](store, stateName, keyEncoder,
-        userKeyEnc, valEncoder, metrics)
+        userKeyEnc, valEncoder, metrics, schemas(stateName).avroEnc)
       TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars")
       mapStateWithoutTTL
     }
@@ -382,10 +389,16 @@ class DriverStatefulProcessorHandleImpl(
 
   private def addTimerColFamily(): Unit = {
     val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString)
+    val secIndexColFamilyName = TimerStateUtils.getSecIndexColFamilyName(timeMode.toString)
     val timerEncoder = new TimerKeyEncoder(keyExprEnc)
     val colFamilySchema = schemaUtils.
       getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow)
+    val secIndexColFamilySchema = schemaUtils.
+      getTimerStateSchemaForSecIndex(secIndexColFamilyName,
+        timerEncoder.keySchemaForSecIndex,
+        timerEncoder.schemaForValueRow)
     columnFamilySchemas.put(stateName, colFamilySchema)
+    columnFamilySchemas.put(secIndexColFamilyName, secIndexColFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.getTimerState(stateName)
     stateVariableInfos.put(stateName, stateVariableInfo)
   }
@@ -404,6 +417,9 @@ class DriverStatefulProcessorHandleImpl(
     val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) {
       false
     } else {
+      val ttlColFamilyName = getTtlColFamilyName(stateName)
+      val ttlColFamilySchema = schemaUtils.getTtlStateSchema(ttlColFamilyName, keyExprEnc)
+      columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema)
       true
     }
 
@@ -432,6 +448,9 @@ class DriverStatefulProcessorHandleImpl(
     val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) {
       false
     } else {
+      val ttlColFamilyName = getTtlColFamilyName(stateName)
+      val ttlColFamilySchema = schemaUtils.getTtlStateSchema(ttlColFamilyName, keyExprEnc)
+      columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema)
       true
     }
 
@@ -459,14 +478,19 @@ class DriverStatefulProcessorHandleImpl(
       ttlConfig: TTLConfig): MapState[K, V] = {
     verifyStateVarOperations("get_map_state", PRE_INIT)
 
+    val userKeyEnc = encoderFor[K]
+    val valEncoder = encoderFor[V]
     val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) {
       false
     } else {
+      val ttlColFamilyName = getTtlColFamilyName(stateName)
+      val ttlColFamilySchema = schemaUtils.getTtlStateSchema(
+        ttlColFamilyName, keyExprEnc, userKeyEnc.schema)
+      columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema)
       true
     }
 
-    val userKeyEnc = encoderFor[K]
-    val valEncoder = encoderFor[V]
-
     val colFamilySchema = schemaUtils.
       getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, ttlEnabled)
     columnFamilySchemas.put(stateName, colFamilySchema)
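Taken together with the previous file, the flow is: `DriverStatefulProcessorHandleImpl` registers a `StateStoreColFamilySchema` (with optional `avroEnc`) per state variable, plus `$ttl_`-prefixed and timer secondary-index companions, and the executor-side handle recovers each serde from the `schemas` map by the same name. A simplified sketch of that registry handshake (the `ColFamilySchema` class here is a hypothetical stand-in):

```scala
import scala.collection.mutable

object SchemaRegistrySketch {
  case class ColFamilySchema(name: String, avroEnc: Option[String])

  def main(args: Array[String]): Unit = {
    val columnFamilySchemas = mutable.Map.empty[String, ColFamilySchema]

    // Driver side: register the state variable and its TTL companion under
    // names derived from the state variable's name.
    val stateName = "mapState"
    val ttlName = "$ttl_" + stateName // matches getTtlColFamilyName
    columnFamilySchemas.put(stateName, ColFamilySchema(stateName, Some("map serde")))
    columnFamilySchemas.put(ttlName, ColFamilySchema(ttlName, Some("ttl serde")))

    // Executor side: the same names recover the serdes, mirroring
    // `schemas(stateName).avroEnc` in StatefulProcessorHandleImpl.
    assert(columnFamilySchemas(stateName).avroEnc.isDefined)
    assert(columnFamilySchemas(ttlName).avroEnc.isDefined)
  }
}
```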