@@ -458,7 +458,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
stateName,
keyExprEnc.schema
)
-    null.asInstanceOf[ValueState[T]]
+    new InvalidHandleValueState[T](stateName)
}

override def getListState[T](
@@ -492,7 +492,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
stateName,
keyExprEnc.schema
)
-    null.asInstanceOf[ListState[T]]
+    new InvalidHandleListState[T](stateName)
}

override def getMapState[K, V](
@@ -522,7 +522,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val stateVariableInfo = TransformWithStateVariableUtils.
getMapState(stateName, ttlEnabled = ttlEnabled)
stateVariableInfos.put(stateName, stateVariableInfo)
-    null.asInstanceOf[MapState[K, V]]
+    new InvalidHandleMapState[K, V](stateName)
}

/**
@@ -655,3 +655,43 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("delete_if_exists", PRE_INIT)
}
}

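/**
 * Placeholder state implementations returned by DriverStatefulProcessorHandleImpl
 * before initialization completes: every operation fails fast with a PRE_INIT error
 * instead of the NullPointerException a null state variable used to produce.
 */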
private[sql] trait InvalidHandleState {
protected val stateName: String

protected def throwInitPhaseError(operation: String): Nothing = {
throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(
s"$stateName.$operation", PRE_INIT.toString)
}
}

private[sql] class InvalidHandleValueState[S](override val stateName: String)
extends ValueState[S] with InvalidHandleState {
override def exists(): Boolean = throwInitPhaseError("exists")
override def get(): S = throwInitPhaseError("get")
override def update(newState: S): Unit = throwInitPhaseError("update")
override def clear(): Unit = throwInitPhaseError("clear")
}

private[sql] class InvalidHandleListState[S](override val stateName: String)
extends ListState[S] with InvalidHandleState {
override def exists(): Boolean = throwInitPhaseError("exists")
override def get(): Iterator[S] = throwInitPhaseError("get")
override def put(newState: Array[S]): Unit = throwInitPhaseError("put")
override def appendValue(newState: S): Unit = throwInitPhaseError("appendValue")
override def appendList(newState: Array[S]): Unit = throwInitPhaseError("appendList")
override def clear(): Unit = throwInitPhaseError("clear")
}

private[sql] class InvalidHandleMapState[K, V](override val stateName: String)
extends MapState[K, V] with InvalidHandleState {
override def exists(): Boolean = throwInitPhaseError("exists")
override def getValue(key: K): V = throwInitPhaseError("getValue")
override def containsKey(key: K): Boolean = throwInitPhaseError("containsKey")
override def updateValue(key: K, value: V): Unit = throwInitPhaseError("updateValue")
override def iterator(): Iterator[(K, V)] = throwInitPhaseError("iterator")
override def keys(): Iterator[K] = throwInitPhaseError("keys")
override def values(): Iterator[V] = throwInitPhaseError("values")
override def removeKey(key: K): Unit = throwInitPhaseError("removeKey")
override def clear(): Unit = throwInitPhaseError("clear")
}
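For illustration, here is a hypothetical processor (not part of this diff) showing the behavior the placeholder classes above enforce. Touching a state variable inside init() runs against the driver-side handle while it is still in PRE_INIT; previously the handle returned null and the call died with an opaque NullPointerException, whereas it now throws a descriptive error naming the state variable and operation.

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Hypothetical example, not part of this patch.
class EagerExistsProcessor extends StatefulProcessor[String, String, String] {
  @transient private var count: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    count = getHandle.getValueState[Long]("count", Encoders.scalaLong, TTLConfig.NONE)
    // Before this change: NullPointerException, since the driver-side handle returned null.
    // After: InvalidHandleValueState throws the PRE_INIT error with
    // operationType "count.exists" and handleState "PRE_INIT".
    count.exists()
  }

  override def handleInputRows(
      key: String,
      rows: Iterator[String],
      timerValues: TimerValues): Iterator[String] = Iterator.empty
}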
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
import java.time.Duration

import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.execution.streaming.{ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.{ListStateImplWithTTL, MemoryStream}
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -36,25 +36,22 @@ import org.apache.spark.sql.streaming.util.StreamManualClock
// used to add a record into the secondary index for every appendList call.
class MultiStatefulVariableTTLProcessor(ttlConfig: TTLConfig)
extends StatefulProcessor[String, String, (String, Long)]{
-  @transient private var _listState: ListStateImplWithTTL[String] = _
+  @transient private var _listState: ListState[String] = _
// Map from index to count
-  @transient private var _mapState: MapStateImplWithTTL[Long, Long] = _
+  @transient private var _mapState: MapState[Long, Long] = _
// Counts the number of times the string has occurred. It should always be
// equal to the size of the list state at the start and end of handleInputRows.
-  @transient private var _valueState: ValueStateImplWithTTL[Long] = _
+  @transient private var _valueState: ValueState[Long] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_listState = getHandle
.getListState("listState", Encoders.STRING, ttlConfig)
-      .asInstanceOf[ListStateImplWithTTL[String]]
_mapState = getHandle
.getMapState("mapState", Encoders.scalaLong, Encoders.scalaLong, ttlConfig)
-      .asInstanceOf[MapStateImplWithTTL[Long, Long]]
_valueState = getHandle
.getValueState("valueState", Encoders.scalaLong, ttlConfig)
-      .asInstanceOf[ValueStateImplWithTTL[Long]]
}
override def handleInputRows(
key: String,
@@ -94,14 +91,13 @@ class MultiStatefulVariableTTLProcessor(ttlConfig: TTLConfig)
class ListStateTTLProcessor(ttlConfig: TTLConfig)
extends StatefulProcessor[String, InputEvent, OutputEvent] {

-  @transient private var _listState: ListStateImplWithTTL[Int] = _
+  @transient private var _listState: ListState[Int] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_listState = getHandle
.getListState("listState", Encoders.scalaInt, ttlConfig)
-      .asInstanceOf[ListStateImplWithTTL[Int]]
}

override def handleInputRows(
@@ -111,7 +107,8 @@ class ListStateTTLProcessor(ttlConfig: TTLConfig)
var results = List[OutputEvent]()

inputRows.foreach { row =>
-      val resultIter = processRow(row, _listState)
+      val resultIter = processRow(row,
+        _listState.asInstanceOf[ListStateImplWithTTL[Int]])
resultIter.foreach { r =>
results = r :: results
}
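A note on the recurring change above (it reappears in the map-state suite below): the driver-side handle now hands back InvalidHandle* placeholders instead of null during pre-init analysis, so a cast like .asInstanceOf[ListStateImplWithTTL[String]] inside init(), which silently succeeded on a null reference, would now fail with a ClassCastException on the driver. The test processors therefore declare their fields with the public interface types and defer the cast to handleInputRows, where the executor-side handle has supplied the real TTL implementation. A minimal sketch of the pattern (the helper name is hypothetical):

// Field stays at the public API type, safe against the driver's placeholder objects...
@transient private var _listState: ListState[Int] = _

// ...and the TTL-aware implementation is recovered only at the point of use,
// inside handleInputRows, where real executor-side state exists.
private def listStateImpl: ListStateImplWithTTL[Int] =
  _listState.asInstanceOf[ListStateImplWithTTL[Int]]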
@@ -28,14 +28,13 @@ import org.apache.spark.sql.streaming.util.StreamManualClock
class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
extends StatefulProcessor[String, InputEvent, OutputEvent] {

-  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+  @transient private var _mapState: MapState[String, Int] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_mapState = getHandle
.getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
-      .asInstanceOf[MapStateImplWithTTL[String, Int]]
}

override def handleInputRows(
@@ -45,7 +44,7 @@ class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
var results = List[OutputEvent]()

for (row <- inputRows) {
-      val resultIter = processRow(row, _mapState)
+      val resultIter = processRow(row, _mapState.asInstanceOf[MapStateImplWithTTL[String, Int]])
resultIter.foreach { r =>
results = r :: results
}
@@ -107,14 +106,14 @@ case class MapOutputEvent(
class MapStateTTLProcessor(ttlConfig: TTLConfig)
extends StatefulProcessor[String, MapInputEvent, MapOutputEvent] {

-  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+  @transient private var _mapState: MapState[String, Int] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_mapState = getHandle
.getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
-      .asInstanceOf[MapStateImplWithTTL[String, Int]]

}

override def handleInputRows(
@@ -124,7 +123,7 @@ class MapStateTTLProcessor(ttlConfig: TTLConfig)
var results = List[MapOutputEvent]()

for (row <- inputRows) {
-      val resultIter = processRow(row, _mapState)
+      val resultIter = processRow(row, _mapState.asInstanceOf[MapStateImplWithTTL[String, Int]])
resultIter.foreach { r =>
results = r :: results
}
@@ -764,6 +764,63 @@ class SleepingTimerProcessor extends StatefulProcessor[String, String, String] {
}
}

class TestMapStateExistsInInit extends StatefulProcessor[String, String, String] {
@transient var _mapState: MapState[String, String] = _

override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_mapState = getHandle.getMapState[String, String](
"mapState", Encoders.STRING, Encoders.STRING, TTLConfig.NONE)

// This should fail as we can't call exists() during init
val exists = _mapState.exists()
}

override def handleInputRows(
key: String,
rows: Iterator[String],
timerValues: TimerValues): Iterator[String] = {
Iterator.empty
}
}

class TestValueStateExistsInInit extends StatefulProcessor[String, String, String] {
@transient var _valueState: ValueState[String] = _

override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_valueState = getHandle.getValueState[String](
"valueState", Encoders.STRING, TTLConfig.NONE)

// This should fail as we can't call exists() during init
val exists = _valueState.exists()
}

override def handleInputRows(
key: String,
rows: Iterator[String],
timerValues: TimerValues): Iterator[String] = {
Iterator.empty
}
}

class TestListStateExistsInInit extends StatefulProcessor[String, String, String] {
@transient var _listState: ListState[String] = _

override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_listState = getHandle.getListState[String](
"listState", Encoders.STRING, TTLConfig.NONE)

// This should fail as we can't call exists() during init
val exists = _listState.exists()
}

override def handleInputRows(
key: String,
rows: Iterator[String],
timerValues: TimerValues): Iterator[String] = {
Iterator.empty
}
}

/**
* Class that adds tests for transformWithState stateful streaming operator
*/
@@ -2343,6 +2400,90 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest {
)
}

test("transformWithState - ValueState.exists() should fail in init") {
withSQLConf(
SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {

val inputData = MemoryStream[String]
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new TestValueStateExistsInInit(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
AddData(inputData, "a"),
ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error =>
checkError(
error.asInstanceOf[SparkUnsupportedOperationException],
condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE",
parameters = Map(
"operationType" -> "valueState.exists",
"handleState" -> "PRE_INIT")
)
}
)
}
}

test("transformWithState - MapState.exists() should fail in init") {
withSQLConf(
SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {

val inputData = MemoryStream[String]
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new TestMapStateExistsInInit(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
AddData(inputData, "a"),
ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error =>
checkError(
error.asInstanceOf[SparkUnsupportedOperationException],
condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE",
parameters = Map(
"operationType" -> "mapState.exists",
"handleState" -> "PRE_INIT")
)
}
)
}
}

test("transformWithState - ListState.exists() should fail in init") {
withSQLConf(
SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {

val inputData = MemoryStream[String]
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new TestListStateExistsInInit(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
AddData(inputData, "a"),
ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error =>
checkError(
error.asInstanceOf[SparkUnsupportedOperationException],
condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE",
parameters = Map(
"operationType" -> "listState.exists",
"handleState" -> "PRE_INIT")
)
}
)
}
}
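For reference, a hedged sketch of the error type these three tests expect. The actual definition is not part of this diff; the constructor and parent class below are assumptions reconstructed from the checkError assertions above.

// Hypothetical reconstruction; the real class lives in Spark's error utilities.
class StatefulProcessorCannotPerformOperationWithInvalidHandleState(
    operationType: String,  // e.g. "valueState.exists", built as s"$stateName.$operation"
    handleState: String)    // e.g. "PRE_INIT"
  extends SparkUnsupportedOperationException(
    errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE",
    messageParameters = Map(
      "operationType" -> operationType,
      "handleState" -> handleState))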

test("transformWithStateWithInitialState - streaming with hdfsStateStoreProvider should fail") {
val inputData = MemoryStream[InitInputRow]
val initDf = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS()