diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ca2f6dd7a84b..73541c22c630 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -383,8 +383,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil case logical.FlatMapGroupsWithState( - f, key, value, grouping, data, output, _, _, _, _, child) => - execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil + f, key, value, grouping, data, output, _, _, _, timeout, child) => + execution.MapGroupsExec( + f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 48c7b80bffe0..34391818f3b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.streaming.GroupStateTimeout import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -361,8 +362,11 @@ object MapGroupsExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, + timeoutConf: GroupStateTimeout, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => { + func(key, values, GroupStateImpl.createForBatch(timeoutConf)) + } new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index bd8d5d7b43d3..3ceb4cf84a41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -215,7 +215,7 @@ case class FlatMapGroupsWithStateExec( val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) - val keyedState = new GroupStateImpl( + val keyedState = GroupStateImpl.createForStreaming( stateObjOption, batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index d4606fd5a846..4401e86936af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -38,20 +38,13 @@ import org.apache.spark.unsafe.types.CalendarInterval * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ -private[sql] class GroupStateImpl[S]( +private[sql] class GroupStateImpl[S] private( optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, override val hasTimedOut: Boolean) extends GroupState[S] { - // Constructor to create dummy state when using mapGroupsWithState in a batch query - def this(optionalValue: Option[S]) = this( - optionalValue, - batchProcessingTimeMs = NO_TIMESTAMP, - eventTimeWatermarkMs = NO_TIMESTAMP, - timeoutConf = GroupStateTimeout.NoTimeout, - hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined private var updated: Boolean = false // whether value has been updated (but not removed) @@ -102,12 +95,7 @@ private[sql] class GroupStateImpl[S]( if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") } - if (batchProcessingTimeMs != NO_TIMESTAMP) { - timeoutTimestamp = durationMs + batchProcessingTimeMs - } else { - // This is being called in a batch query, hence no processing timestamp. - // Just ignore any attempts to set timeout. - } + timeoutTimestamp = durationMs + batchProcessingTimeMs } override def setTimeoutDuration(duration: String): Unit = { @@ -128,12 +116,7 @@ private[sql] class GroupStateImpl[S]( s"Timeout timestamp ($timestampMs) cannot be earlier than the " + s"current watermark ($eventTimeWatermarkMs)") } - if (batchProcessingTimeMs != NO_TIMESTAMP) { - timeoutTimestamp = timestampMs - } else { - // This is being called in a batch query, hence no processing timestamp. - // Just ignore any attempts to set timeout. - } + timeoutTimestamp = timestampMs } @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") @@ -213,4 +196,23 @@ private[sql] class GroupStateImpl[S]( private[sql] object GroupStateImpl { // Value used represent the lack of valid timestamp as a long val NO_TIMESTAMP = -1L + + def createForStreaming[S]( + optionalValue: Option[S], + batchProcessingTimeMs: Long, + eventTimeWatermarkMs: Long, + timeoutConf: GroupStateTimeout, + hasTimedOut: Boolean): GroupStateImpl[S] = { + new GroupStateImpl[S]( + optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut) + } + + def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = { + new GroupStateImpl[Any]( + optionalValue = None, + batchProcessingTimeMs = NO_TIMESTAMP, + eventTimeWatermarkMs = NO_TIMESTAMP, + timeoutConf, + hasTimedOut = false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 10e91740eb92..6bb9408ce99e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -73,14 +73,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.hasRemoved === shouldBeRemoved) } + // === Tests for state in streaming queries === // Updating empty state - state = new GroupStateImpl[String](None) + state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = new GroupStateImpl[String](Some("2")) + state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -99,25 +100,34 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("GroupState - setTimeout**** with NoTimeout") { - for (initState <- Seq(None, Some(5))) { - // for different initial state - implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + for (initValue <- Seq(None, Some(5))) { + val states = Seq( + GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), + GroupStateImpl.createForBatch(NoTimeout) + ) + for (state <- states) { + // for streaming queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } } } test("GroupState - setTimeout**** with ProcessingTimeTimeout") { - implicit var state: GroupStateImpl[Int] = null - - state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + // for streaming queries + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) - assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state + assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.update(5) - assert(state.getTimeoutTimestamp === 1500) // does not change + assert(state.getTimeoutTimestamp === 1500) // does not change state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) state.setTimeoutDuration("2 second") @@ -125,22 +135,38 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.remove() - assert(state.getTimeoutTimestamp === 3000) // does not change - state.setTimeoutDuration(500) // can still be set + assert(state.getTimeoutTimestamp === 3000) // does not change + state.setTimeoutDuration(500) // can still be set assert(state.getTimeoutTimestamp === 1500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.update(5) + state.setTimeoutDuration(1000) + state.setTimeoutDuration("2 second") + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } test("GroupState - setTimeout**** with EventTimeTimeout") { - implicit val state = new GroupStateImpl[Int]( - None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, EventTimeTimeout, false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) - assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state + assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state state.update(5) - assert(state.getTimeoutTimestamp === 5000) // does not change + assert(state.getTimeoutTimestamp === 5000) // does not change state.setTimeoutTimestamp(10000) assert(state.getTimeoutTimestamp === 10000) state.setTimeoutTimestamp(new Date(20000)) @@ -150,7 +176,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf state.remove() assert(state.getTimeoutTimestamp === 20000) state.setTimeoutTimestamp(5000) - assert(state.getTimeoutTimestamp === 5000) // can be set after removing state + assert(state.getTimeoutTimestamp === 5000) // can be set after removing state + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + state.setTimeoutTimestamp(5000) + + state.update(5) + state.setTimeoutTimestamp(10000) + state.setTimeoutTimestamp(new Date(20000)) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutTimestamp(5000) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) } @@ -165,7 +206,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } @@ -182,7 +224,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf state.setTimeoutDuration("1 month -1 day") } - state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } @@ -211,23 +254,32 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("GroupState - hasTimedOut") { for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + // for streaming queries for (initState <- Seq(None, Some(5))) { - val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + val state1 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = false) assert(state1.hasTimedOut === false) - val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + + val state2 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = true) assert(state2.hasTimedOut === true) } + + // for batch queries + assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) } } test("GroupState - primitive type") { - var intState = new GroupStateImpl[Int](None) + var intState = GroupStateImpl.createForStreaming[Int]( + None, 1000, 1000, NoTimeout, hasTimedOut = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) - intState = new GroupStateImpl[Int](Some(10)) + intState = GroupStateImpl.createForStreaming[Int]( + Some(10), 1000, 1000, NoTimeout, hasTimedOut = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -243,7 +295,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" @@ -748,15 +799,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + // Test the following + // - no initial state + // - timeouts operations work, does not throw any error [SPARK-20792] + // - works with primitive state type + val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") + state.setTimeoutTimestamp(0, "1 hour") + state.update(10) (key, values.size) } checkAnswer( spark.createDataset(Seq("a", "a", "b")) .groupByKey(x => x) - .mapGroupsWithState(stateFunc) + .mapGroupsWithState(EventTimeTimeout)(stateFunc) .toDF, spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) }