From 700d882e3643a473f19c1a7c2082cacc794bedae Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 16 Apr 2025 11:52:48 -0700 Subject: [PATCH 1/6] Throwing classified error when disallowed functions are called during StatefulProcessor.init() --- .../streaming/DriverStateState.scala | 63 ++++++++ .../StatefulProcessorHandleImpl.scala | 6 +- .../streaming/TransformWithStateSuite.scala | 141 ++++++++++++++++++ 3 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverStateState.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverStateState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverStateState.scala new file mode 100644 index 000000000000..c1438990fe46 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverStateState.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE_INIT +import org.apache.spark.sql.execution.streaming.state.StateStoreErrors +import org.apache.spark.sql.streaming.{ListState, MapState, ValueState} + +// First, let's create a common base trait for all driver-side state implementations +trait DriverSideState { + protected val stateName: String + + protected def throwInitPhaseError(operation: String): Nothing = { + throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState( + s"$stateName.$operation", PRE_INIT.toString) + } +} + +// Then implement each state type +class DriverSideValueState[S](override val stateName: String) + extends ValueState[S] with DriverSideState { + override def exists(): Boolean = throwInitPhaseError("exists") + override def get(): S = throwInitPhaseError("get") + override def update(newState: S): Unit = throwInitPhaseError("update") + override def clear(): Unit = throwInitPhaseError("clear") +} + +class DriverSideListState[S](override val stateName: String) + extends ListState[S] with DriverSideState { + override def exists(): Boolean = throwInitPhaseError("exists") + override def get(): Iterator[S] = throwInitPhaseError("get") + override def put(newState: Array[S]): Unit = throwInitPhaseError("put") + override def appendValue(newState: S): Unit = throwInitPhaseError("appendValue") + override def appendList(newState: Array[S]): Unit = throwInitPhaseError("appendList") + override def clear(): Unit = throwInitPhaseError("clear") +} + +class DriverSideMapState[K, V](override val stateName: String) + extends MapState[K, V] with DriverSideState { + override def exists(): Boolean = throwInitPhaseError("exists") + override def getValue(key: K): V = throwInitPhaseError("getValue") + override def containsKey(key: K): Boolean = throwInitPhaseError("containsKey") + override def updateValue(key: K, value: V): Unit = throwInitPhaseError("updateValue") + override def iterator(): Iterator[(K, V)] = throwInitPhaseError("iterator") + override def keys(): Iterator[K] = throwInitPhaseError("keys") + override def values(): Iterator[V] = throwInitPhaseError("values") + override def removeKey(key: K): Unit = throwInitPhaseError("removeKey") + override def clear(): Unit = throwInitPhaseError("clear") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index f1f0ddf206c6..d834a4a93023 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -458,7 +458,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi stateName, keyExprEnc.schema ) - null.asInstanceOf[ValueState[T]] + new DriverSideValueState[T](stateName) } override def getListState[T]( @@ -492,7 +492,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi stateName, keyExprEnc.schema ) - null.asInstanceOf[ListState[T]] + new DriverSideListState[T](stateName) } override def getMapState[K, V]( @@ -522,7 +522,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi val stateVariableInfo = TransformWithStateVariableUtils. getMapState(stateName, ttlEnabled = ttlEnabled) stateVariableInfos.put(stateName, stateVariableInfo) - null.asInstanceOf[MapState[K, V]] + new DriverSideMapState[K, V](stateName) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 6f1da588eb53..d91cc8f32a79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -764,6 +764,63 @@ class SleepingTimerProcessor extends StatefulProcessor[String, String, String] { } } +class TestMapStateExistsInInit extends StatefulProcessor[String, String, String] { + @transient var _mapState: MapState[String, String] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + _mapState = getHandle.getMapState[String, String]( + "mapState", Encoders.STRING, Encoders.STRING, TTLConfig.NONE) + + // This should fail as we can't call exists() during init + val exists = _mapState.exists() + } + + override def handleInputRows( + key: String, + rows: Iterator[String], + timerValues: TimerValues): Iterator[String] = { + Iterator.empty + } +} + +class TestValueStateExistsInInit extends StatefulProcessor[String, String, String] { + @transient var _valueState: ValueState[String] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + _valueState = getHandle.getValueState[String]( + "valueState", Encoders.STRING, TTLConfig.NONE) + + // This should fail as we can't call exists() during init + val exists = _valueState.exists() + } + + override def handleInputRows( + key: String, + rows: Iterator[String], + timerValues: TimerValues): Iterator[String] = { + Iterator.empty + } +} + +class TestListStateExistsInInit extends StatefulProcessor[String, String, String] { + @transient var _listState: ListState[String] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + _listState = getHandle.getListState[String]( + "listState", Encoders.STRING, TTLConfig.NONE) + + // This should fail as we can't call exists() during init + val exists = _listState.exists() + } + + override def handleInputRows( + key: String, + rows: Iterator[String], + timerValues: TimerValues): Iterator[String] = { + Iterator.empty + } +} + /** * Class that adds tests for transformWithState stateful streaming operator */ @@ -872,6 +929,90 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - ValueState.exists() should fail in init") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new TestValueStateExistsInInit(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => + checkError( + error.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + parameters = Map( + "operationType" -> "valueState.exists", + "handleState" -> "PRE_INIT") + ) + } + ) + } + } + + test("transformWithState - MapState.exists() should fail in init") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new TestMapStateExistsInInit(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => + checkError( + error.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + parameters = Map( + "operationType" -> "mapState.exists", + "handleState" -> "PRE_INIT") + ) + } + ) + } + } + + test("transformWithState - ListState.exists() should fail in init") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new TestListStateExistsInInit(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => + checkError( + error.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + parameters = Map( + "operationType" -> "listState.exists", + "handleState" -> "PRE_INIT") + ) + } + ) + } + } + test("transformWithState - streaming with rocksdb should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, From 3d7ac47522d314941f2eab690d06d1efaf520b78 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 16 Apr 2025 12:11:15 -0700 Subject: [PATCH 2/6] removing comments --- .../streaming/{DriverStateState.scala => DriverSideState.scala} | 2 -- 1 file changed, 2 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{DriverStateState.scala => DriverSideState.scala} (96%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverStateState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverSideState.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverStateState.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverSideState.scala index c1438990fe46..ff9e47b0aa0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverStateState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverSideState.scala @@ -20,7 +20,6 @@ import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.streaming.{ListState, MapState, ValueState} -// First, let's create a common base trait for all driver-side state implementations trait DriverSideState { protected val stateName: String @@ -30,7 +29,6 @@ trait DriverSideState { } } -// Then implement each state type class DriverSideValueState[S](override val stateName: String) extends ValueState[S] with DriverSideState { override def exists(): Boolean = throwInitPhaseError("exists") From e6699c0f31d56e2912842a83b32bccafd9efdeeb Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 16 Apr 2025 13:34:02 -0700 Subject: [PATCH 3/6] moving to statefulprocessorhandleimpl --- .../execution/streaming/DriverSideState.scala | 61 ------------------- .../StatefulProcessorHandleImpl.scala | 40 ++++++++++++ 2 files changed, 40 insertions(+), 61 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverSideState.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverSideState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverSideState.scala deleted file mode 100644 index ff9e47b0aa0d..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/DriverSideState.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE_INIT -import org.apache.spark.sql.execution.streaming.state.StateStoreErrors -import org.apache.spark.sql.streaming.{ListState, MapState, ValueState} - -trait DriverSideState { - protected val stateName: String - - protected def throwInitPhaseError(operation: String): Nothing = { - throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState( - s"$stateName.$operation", PRE_INIT.toString) - } -} - -class DriverSideValueState[S](override val stateName: String) - extends ValueState[S] with DriverSideState { - override def exists(): Boolean = throwInitPhaseError("exists") - override def get(): S = throwInitPhaseError("get") - override def update(newState: S): Unit = throwInitPhaseError("update") - override def clear(): Unit = throwInitPhaseError("clear") -} - -class DriverSideListState[S](override val stateName: String) - extends ListState[S] with DriverSideState { - override def exists(): Boolean = throwInitPhaseError("exists") - override def get(): Iterator[S] = throwInitPhaseError("get") - override def put(newState: Array[S]): Unit = throwInitPhaseError("put") - override def appendValue(newState: S): Unit = throwInitPhaseError("appendValue") - override def appendList(newState: Array[S]): Unit = throwInitPhaseError("appendList") - override def clear(): Unit = throwInitPhaseError("clear") -} - -class DriverSideMapState[K, V](override val stateName: String) - extends MapState[K, V] with DriverSideState { - override def exists(): Boolean = throwInitPhaseError("exists") - override def getValue(key: K): V = throwInitPhaseError("getValue") - override def containsKey(key: K): Boolean = throwInitPhaseError("containsKey") - override def updateValue(key: K, value: V): Unit = throwInitPhaseError("updateValue") - override def iterator(): Iterator[(K, V)] = throwInitPhaseError("iterator") - override def keys(): Iterator[K] = throwInitPhaseError("keys") - override def values(): Iterator[V] = throwInitPhaseError("values") - override def removeKey(key: K): Unit = throwInitPhaseError("removeKey") - override def clear(): Unit = throwInitPhaseError("clear") -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index d834a4a93023..9936755f1317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -655,3 +655,43 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("delete_if_exists", PRE_INIT) } } + +private[sql] trait DriverSideState { + protected val stateName: String + + protected def throwInitPhaseError(operation: String): Nothing = { + throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState( + s"$stateName.$operation", PRE_INIT.toString) + } +} + +private[sql] class DriverSideValueState[S](override val stateName: String) + extends ValueState[S] with DriverSideState { + override def exists(): Boolean = throwInitPhaseError("exists") + override def get(): S = throwInitPhaseError("get") + override def update(newState: S): Unit = throwInitPhaseError("update") + override def clear(): Unit = throwInitPhaseError("clear") +} + +private[sql] class DriverSideListState[S](override val stateName: String) + extends ListState[S] with DriverSideState { + override def exists(): Boolean = throwInitPhaseError("exists") + override def get(): Iterator[S] = throwInitPhaseError("get") + override def put(newState: Array[S]): Unit = throwInitPhaseError("put") + override def appendValue(newState: S): Unit = throwInitPhaseError("appendValue") + override def appendList(newState: Array[S]): Unit = throwInitPhaseError("appendList") + override def clear(): Unit = throwInitPhaseError("clear") +} + +private[sql] class DriverSideMapState[K, V](override val stateName: String) + extends MapState[K, V] with DriverSideState { + override def exists(): Boolean = throwInitPhaseError("exists") + override def getValue(key: K): V = throwInitPhaseError("getValue") + override def containsKey(key: K): Boolean = throwInitPhaseError("containsKey") + override def updateValue(key: K, value: V): Unit = throwInitPhaseError("updateValue") + override def iterator(): Iterator[(K, V)] = throwInitPhaseError("iterator") + override def keys(): Iterator[K] = throwInitPhaseError("keys") + override def values(): Iterator[V] = throwInitPhaseError("values") + override def removeKey(key: K): Unit = throwInitPhaseError("removeKey") + override def clear(): Unit = throwInitPhaseError("clear") +} From 3b06f6710462cff3265d06d05380d80bf0c1d364 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 16 Apr 2025 13:55:03 -0700 Subject: [PATCH 4/6] moving to validationsuite --- .../streaming/TransformWithStateSuite.scala | 168 +++++++++--------- 1 file changed, 84 insertions(+), 84 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index d91cc8f32a79..bda26da88679 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -929,90 +929,6 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - ValueState.exists() should fail in init") { - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new TestValueStateExistsInInit(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result, OutputMode.Update())( - AddData(inputData, "a"), - ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => - checkError( - error.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", - parameters = Map( - "operationType" -> "valueState.exists", - "handleState" -> "PRE_INIT") - ) - } - ) - } - } - - test("transformWithState - MapState.exists() should fail in init") { - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new TestMapStateExistsInInit(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result, OutputMode.Update())( - AddData(inputData, "a"), - ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => - checkError( - error.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", - parameters = Map( - "operationType" -> "mapState.exists", - "handleState" -> "PRE_INIT") - ) - } - ) - } - } - - test("transformWithState - ListState.exists() should fail in init") { - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new TestListStateExistsInInit(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result, OutputMode.Update())( - AddData(inputData, "a"), - ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => - checkError( - error.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", - parameters = Map( - "operationType" -> "listState.exists", - "handleState" -> "PRE_INIT") - ) - } - ) - } - } - test("transformWithState - streaming with rocksdb should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -2484,6 +2400,90 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { ) } + test("transformWithState - ValueState.exists() should fail in init") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new TestValueStateExistsInInit(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => + checkError( + error.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + parameters = Map( + "operationType" -> "valueState.exists", + "handleState" -> "PRE_INIT") + ) + } + ) + } + } + + test("transformWithState - MapState.exists() should fail in init") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new TestMapStateExistsInInit(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => + checkError( + error.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + parameters = Map( + "operationType" -> "mapState.exists", + "handleState" -> "PRE_INIT") + ) + } + ) + } + } + + test("transformWithState - ListState.exists() should fail in init") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new TestListStateExistsInInit(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { error => + checkError( + error.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + parameters = Map( + "operationType" -> "listState.exists", + "handleState" -> "PRE_INIT") + ) + } + ) + } + } + test("transformWithStateWithInitialState - streaming with hdfsStateStoreProvider should fail") { val inputData = MemoryStream[InitInputRow] val initDf = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS() From 7f28f15b89bed975b2eab4049ab6ac0691259c2f Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 16 Apr 2025 15:18:23 -0700 Subject: [PATCH 5/6] renaming --- .../StatefulProcessorHandleImpl.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 9936755f1317..3cf3286fafb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -458,7 +458,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi stateName, keyExprEnc.schema ) - new DriverSideValueState[T](stateName) + new InvalidHandleValueState[T](stateName) } override def getListState[T]( @@ -492,7 +492,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi stateName, keyExprEnc.schema ) - new DriverSideListState[T](stateName) + new InvalidHandleListState[T](stateName) } override def getMapState[K, V]( @@ -522,7 +522,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi val stateVariableInfo = TransformWithStateVariableUtils. getMapState(stateName, ttlEnabled = ttlEnabled) stateVariableInfos.put(stateName, stateVariableInfo) - new DriverSideMapState[K, V](stateName) + new InvalidHandleMapState[K, V](stateName) } /** @@ -656,7 +656,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi } } -private[sql] trait DriverSideState { +private[sql] trait InvalidHandleState { protected val stateName: String protected def throwInitPhaseError(operation: String): Nothing = { @@ -665,16 +665,16 @@ private[sql] trait DriverSideState { } } -private[sql] class DriverSideValueState[S](override val stateName: String) - extends ValueState[S] with DriverSideState { +private[sql] class InvalidHandleValueState[S](override val stateName: String) + extends ValueState[S] with InvalidHandleState { override def exists(): Boolean = throwInitPhaseError("exists") override def get(): S = throwInitPhaseError("get") override def update(newState: S): Unit = throwInitPhaseError("update") override def clear(): Unit = throwInitPhaseError("clear") } -private[sql] class DriverSideListState[S](override val stateName: String) - extends ListState[S] with DriverSideState { +private[sql] class InvalidHandleListState[S](override val stateName: String) + extends ListState[S] with InvalidHandleState { override def exists(): Boolean = throwInitPhaseError("exists") override def get(): Iterator[S] = throwInitPhaseError("get") override def put(newState: Array[S]): Unit = throwInitPhaseError("put") @@ -683,8 +683,8 @@ private[sql] class DriverSideListState[S](override val stateName: String) override def clear(): Unit = throwInitPhaseError("clear") } -private[sql] class DriverSideMapState[K, V](override val stateName: String) - extends MapState[K, V] with DriverSideState { +private[sql] class InvalidHandleMapState[K, V](override val stateName: String) + extends MapState[K, V] with InvalidHandleState { override def exists(): Boolean = throwInitPhaseError("exists") override def getValue(key: K): V = throwInitPhaseError("getValue") override def containsKey(key: K): Boolean = throwInitPhaseError("containsKey") From 751d807fc95f2c2921068f23c1102d368852ebb5 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 22 Apr 2025 09:52:57 -0700 Subject: [PATCH 6/6] removing class cast from init function --- .../TransformWithListStateTTLSuite.scala | 17 ++++++------- .../TransformWithMapStateTTLSuite.scala | 11 ++++---- .../TransformWithValueStateTTLSuite.scala | 25 ++++++++----------- 3 files changed, 23 insertions(+), 30 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala index d04573becf1a..bd3667b16591 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import java.time.Duration import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.{ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.{ListStateImplWithTTL, MemoryStream} import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -36,25 +36,22 @@ import org.apache.spark.sql.streaming.util.StreamManualClock // used to add a record into the secondary index for every appendList call. class MultiStatefulVariableTTLProcessor(ttlConfig: TTLConfig) extends StatefulProcessor[String, String, (String, Long)]{ - @transient private var _listState: ListStateImplWithTTL[String] = _ + @transient private var _listState: ListState[String] = _ // Map from index to count - @transient private var _mapState: MapStateImplWithTTL[Long, Long] = _ + @transient private var _mapState: MapState[Long, Long] = _ // Counts the number of times the string has occurred. It should always be // equal to the size of the list state at the start and end of handleInputRows. - @transient private var _valueState: ValueStateImplWithTTL[Long] = _ + @transient private var _valueState: ValueState[Long] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { _listState = getHandle .getListState("listState", Encoders.STRING, ttlConfig) - .asInstanceOf[ListStateImplWithTTL[String]] _mapState = getHandle .getMapState("mapState", Encoders.scalaLong, Encoders.scalaLong, ttlConfig) - .asInstanceOf[MapStateImplWithTTL[Long, Long]] _valueState = getHandle .getValueState("valueState", Encoders.scalaLong, ttlConfig) - .asInstanceOf[ValueStateImplWithTTL[Long]] } override def handleInputRows( key: String, @@ -94,14 +91,13 @@ class MultiStatefulVariableTTLProcessor(ttlConfig: TTLConfig) class ListStateTTLProcessor(ttlConfig: TTLConfig) extends StatefulProcessor[String, InputEvent, OutputEvent] { - @transient private var _listState: ListStateImplWithTTL[Int] = _ + @transient private var _listState: ListState[Int] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { _listState = getHandle .getListState("listState", Encoders.scalaInt, ttlConfig) - .asInstanceOf[ListStateImplWithTTL[Int]] } override def handleInputRows( @@ -111,7 +107,8 @@ class ListStateTTLProcessor(ttlConfig: TTLConfig) var results = List[OutputEvent]() inputRows.foreach { row => - val resultIter = processRow(row, _listState) + val resultIter = processRow(row, + _listState.asInstanceOf[ListStateImplWithTTL[Int]]) resultIter.foreach { r => results = r :: results } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala index 2cb15263459e..c845059f20fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala @@ -28,14 +28,13 @@ import org.apache.spark.sql.streaming.util.StreamManualClock class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig) extends StatefulProcessor[String, InputEvent, OutputEvent] { - @transient private var _mapState: MapStateImplWithTTL[String, Int] = _ + @transient private var _mapState: MapState[String, Int] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { _mapState = getHandle .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig) - .asInstanceOf[MapStateImplWithTTL[String, Int]] } override def handleInputRows( @@ -45,7 +44,7 @@ class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig) var results = List[OutputEvent]() for (row <- inputRows) { - val resultIter = processRow(row, _mapState) + val resultIter = processRow(row, _mapState.asInstanceOf[MapStateImplWithTTL[String, Int]]) resultIter.foreach { r => results = r :: results } @@ -107,14 +106,14 @@ case class MapOutputEvent( class MapStateTTLProcessor(ttlConfig: TTLConfig) extends StatefulProcessor[String, MapInputEvent, MapOutputEvent] { - @transient private var _mapState: MapStateImplWithTTL[String, Int] = _ + @transient private var _mapState: MapState[String, Int] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { _mapState = getHandle .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig) - .asInstanceOf[MapStateImplWithTTL[String, Int]] + } override def handleInputRows( @@ -124,7 +123,7 @@ class MapStateTTLProcessor(ttlConfig: TTLConfig) var results = List[MapOutputEvent]() for (row <- inputRows) { - val resultIter = processRow(row, _mapState) + val resultIter = processRow(row, _mapState.asInstanceOf[MapStateImplWithTTL[String, Int]]) resultIter.foreach { r => results = r :: results } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 4c682b18eef8..2b33b3feb307 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MemoryStream, ValueStateImpl, ValueStateImplWithTTL} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -90,14 +90,13 @@ class ValueStateTTLProcessor(ttlConfig: TTLConfig) extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient private var _valueState: ValueStateImplWithTTL[Int] = _ + @transient private var _valueState: ValueState[Int] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { _valueState = getHandle .getValueState("valueState", Encoders.scalaInt, ttlConfig) - .asInstanceOf[ValueStateImplWithTTL[Int]] } override def handleInputRows( @@ -107,7 +106,8 @@ class ValueStateTTLProcessor(ttlConfig: TTLConfig) var results = List[OutputEvent]() inputRows.foreach { row => - val resultIter = TTLInputProcessFunction.processRow(row, _valueState) + val resultIter = TTLInputProcessFunction.processRow(row, + _valueState.asInstanceOf[ValueStateImplWithTTL[Int]]) resultIter.foreach { r => results = r :: results } @@ -124,18 +124,16 @@ class MultipleValueStatesTTLProcessor( extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _ - @transient var _valueStateWithoutTTL: ValueStateImpl[Int] = _ + @transient var _valueStateWithTTL: ValueState[Int] = _ + @transient var _valueStateWithoutTTL: ValueState[Int] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { _valueStateWithTTL = getHandle .getValueState("valueStateTTL", Encoders.scalaInt, ttlConfig) - .asInstanceOf[ValueStateImplWithTTL[Int]] _valueStateWithoutTTL = getHandle .getValueState[Int]("valueState", Encoders.scalaInt, TTLConfig.NONE) - .asInstanceOf[ValueStateImpl[Int]] } override def handleInputRows( @@ -146,7 +144,8 @@ class MultipleValueStatesTTLProcessor( if (key == ttlKey) { inputRows.foreach { row => - val resultIterator = TTLInputProcessFunction.processRow(row, _valueStateWithTTL) + val resultIterator = TTLInputProcessFunction.processRow(row, + _valueStateWithTTL.asInstanceOf[ValueStateImplWithTTL[Int]]) resultIterator.foreach { r => results = r :: results } @@ -154,7 +153,7 @@ class MultipleValueStatesTTLProcessor( } else { inputRows.foreach { row => val resultIterator = TTLInputProcessFunction.processNonTTLStateRow(row, - _valueStateWithoutTTL) + _valueStateWithoutTTL.asInstanceOf[ValueStateImpl[Int]]) resultIterator.foreach { r => results = r :: results } @@ -171,8 +170,8 @@ class TTLProcessorWithCompositeTypes( noTtlKey: String, ttlConfig: TTLConfig) extends MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig) { - @transient private var _listStateWithTTL: ListStateImplWithTTL[TestClass] = _ - @transient private var _mapStateWithTTL: MapStateImplWithTTL[POJOTestClass, String] = _ + @transient private var _listStateWithTTL: ListState[TestClass] = _ + @transient private var _mapStateWithTTL: MapState[POJOTestClass, String] = _ override def init( outputMode: OutputMode, @@ -180,11 +179,9 @@ class TTLProcessorWithCompositeTypes( super.init(outputMode, timeMode) _listStateWithTTL = getHandle .getListState("listState", Encoders.product[TestClass], ttlConfig) - .asInstanceOf[ListStateImplWithTTL[TestClass]] _mapStateWithTTL = getHandle .getMapState("mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING, ttlConfig) - .asInstanceOf[MapStateImplWithTTL[POJOTestClass, String]] } }