diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 02ba1c2eed0f7..50f637f9762fb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -44,18 +44,6 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.sum - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } - val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) @@ -72,8 +60,16 @@ object StatefulNetworkWordCount { // Update the cumulative count using updateStateByKey // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, - new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) + + val trackStateFunc = (word: String, one: Option[Int], state: State[Int]) => { + val sum = one.getOrElse(0) + state.getOrElse(0) + val output = (word, sum) + state.update(sum) + Some(output) + } + + val stateDstream = wordDstream.trackStateByKey( + TrackStateSpec(trackStateFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala new file mode 100644 index 0000000000000..d6ea2252c0547 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +/** + * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] and + * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. + * {{{ + * + * }}} + */ +sealed abstract class State[S] { + + /** Whether the state already exists */ + def exists(): Boolean + + /** + * Get the state if it exists, otherwise wise it will throw an exception. + * Check with `exists()` whether the state exists or not before calling `get()`. + */ + def get(): S + + /** + * Update the state with a new value. 
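+   * For example, the tracking function in the updated `StatefulNetworkWordCount` example keeps
+   * a running count per word and calls `update()` on every batch (illustrative sketch of the
+   * same code):
+   * {{{
+   *    val trackStateFunc = (word: String, one: Option[Int], state: State[Int]) => {
+   *      val sum = one.getOrElse(0) + state.getOrElse(0)  // new data plus previously tracked count
+   *      state.update(sum)                                 // save the new count as this key's state
+   *      Some((word, sum))                                 // emit the updated (word, count) pair
+   *    }
+   * }}}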
Note that you cannot update the state if the state is + * timing out (that is, `isTimingOut() return true`, or if the state has already been removed by + * `remove()`. + */ + def update(newState: S): Unit + + /** Remove the state if it exists. */ + def remove(): Unit + + /** Is the state going to be timed out by the system after this batch interval */ + def isTimingOut(): Boolean + + @inline final def getOption(): Option[S] = Option(get()) + + /** Get the state if it exists, otherwise return the default value */ + @inline final def getOrElse[S1 >: S](default: => S1): S1 = { + if (exists) this.get else default + } + + @inline final override def toString() = getOption.map { _.toString }.getOrElse("") +} + +/** Internal implementation of the [[State]] interface */ +private[streaming] class StateImpl[S] extends State[S] { + + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = true + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false + + // ========= Public API ========= + def exists(): Boolean = { + defined + } + + def get(): S = { + state + } + + def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + state = newState + updated = true + } + + def isTimingOut(): Boolean = { + timingOut + } + + def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved(): Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated(): Boolean = { + updated + } + + /** + * Internal method to update the state data and reset internal flags in `this`. + * This method allows `this` object to be reused across many state records. + */ + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + /** + * Internal method to update the state data and reset internal flags in `this`. + * This method allows `this` object to be reused across many state records. + */ + def wrapTiminoutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala new file mode 100644 index 0000000000000..f0edcf2b9bfe6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.rdd.RDD + + +/** + * Abstract class having all the specifications of DStream.trackStateByKey(). + * Use the `TrackStateSpec.create()` or `TrackStateSpec.create()` to create instances of this class. + * + * {{{ + * TrackStateSpec(trackingFunction) // in Scala + * TrackStateSpec.create(trackingFunction) // in Java + * }}} + */ +sealed abstract class TrackStateSpec[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag] + extends Serializable { + + def initialState(rdd: RDD[(K, S)]): this.type + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type + + def numPartitions(numPartitions: Int): this.type + def partitioner(partitioner: Partitioner): this.type + + def timeout(interval: Duration): this.type +} + + +/** Builder object for creating instances of TrackStateSpec */ +object TrackStateSpec { + + def apply[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { + new TrackStateSpecImpl[K, V, S, T](trackingFunction) + } + + def create[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { + apply(trackingFunction) + } +} + + +/** Internal implementation of [[TrackStateSpec]] interface */ +private[streaming] +case class TrackStateSpecImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + function: (K, Option[V], State[S]) => Option[T]) extends TrackStateSpec[K, V, S, T] { + + require(function != null) + + @volatile private var partitioner: Partitioner = null + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null + + + def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd + this + } + + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + + + def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 71bec96d46c8d..bac511b89921b 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,19 +24,18 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.streaming.{Duration, Time, TrackStateSpec, TrackStateSpecImpl} import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} +import org.apache.spark.{HashPartitioner, Partitioner} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) - extends Serializable -{ + extends Serializable { private[streaming] def ssc = self.ssc private[streaming] def sparkContext = self.context.sparkContext @@ -350,6 +349,16 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } + def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): DStream[T] = { + new TrackeStateDStream[K, V, S, T]( + self, + spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]] + ).mapPartitions { partitionIter => + partitionIter.flatMap { _.emittedRecords } + } + } + + /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala new file mode 100644 index 0000000000000..205d6c70ef231 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.dstream + +import java.io.{IOException, ObjectOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.util.StateMap +import org.apache.spark.util.Utils + +private[streaming] case class TrackStateRDDRecord[K: ClassTag, S: ClassTag, T: ClassTag]( + stateMap: StateMap[K, S], emittedRecords: Seq[T]) + + +private[streaming] class TrackStateRDDPartition( + idx: Int, + @transient private var prevStateRDD: RDD[_], + @transient private var partitionedDataRDD: RDD[_]) extends Partition { + + private[dstream] var previousSessionRDDPartition: Partition = null + private[dstream] var partitionedDataRDDPartition: Partition = null + + override def index: Int = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent split at the time of task serialization + previousSessionRDDPartition = prevStateRDD.partitions(index) + partitionedDataRDDPartition = partitionedDataRDD.partitions(index) + oos.defaultWriteObject() + } +} + +private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + _sc: SparkContext, + private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], + private var partitionedDataRDD: RDD[(K, V)], + trackingFunction: (K, Option[V], State[S]) => Option[T], + currentTime: Long, timeoutThresholdTime: Option[Long] + ) extends RDD[TrackStateRDDRecord[K, S, T]]( + _sc, + List( + new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD), + new OneToOneDependency(partitionedDataRDD)) + ) { + + @volatile private var doFullScan = false + + require(partitionedDataRDD.partitioner.nonEmpty) + require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) + + override val partitioner = prevStateRDD.partitioner + + override def checkpoint(): Unit = { + super.checkpoint() + doFullScan = true + } + + override def compute( + partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = { + + val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] + val prevStateRDDIterator = prevStateRDD.iterator( + stateRDDPartition.previousSessionRDDPartition, context) + val dataIterator = partitionedDataRDD.iterator( + stateRDDPartition.partitionedDataRDDPartition, context) + if (!prevStateRDDIterator.hasNext) { + throw new SparkException(s"Could not find state map in previous state RDD") + } + + val newStateMap = prevStateRDDIterator.next().stateMap.copy() + val emittedRecords = new ArrayBuffer[T] + + val wrappedState = new StateImpl[S]() + + dataIterator.foreach { case (key, value) => + wrappedState.wrap(newStateMap.get(key)) + val emittedRecord = trackingFunction(key, Some(value), wrappedState) + if (wrappedState.isRemoved) { + newStateMap.remove(key) + } else if (wrappedState.isUpdated) { + newStateMap.put(key, wrappedState.get(), currentTime) + } + emittedRecords ++= emittedRecord + } + + if (doFullScan) { + if (timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + wrappedState.wrapTiminoutState(state) + val emittedRecord = trackingFunction(key, None, wrappedState) + emittedRecords ++= emittedRecord + } + } + } + + Iterator(TrackStateRDDRecord(newStateMap, 
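+      // One record per partition: `newStateMap` becomes the previous state consumed by the
+      // next batch's TrackStateRDD, while `emittedRecords` holds the tracking function's
+      // output for this batch, which trackStateByKey flattens into the user-visible DStream.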
emittedRecords)) + } + + override protected def getPartitions: Array[Partition] = { + Array.tabulate(prevStateRDD.partitions.length) { i => + new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + } + + override def clearDependencies() { + super.clearDependencies() + prevStateRDD = null + partitionedDataRDD = null + } +} + +private[streaming] object TrackStateRDD { + def createFromPairRDD[K: ClassTag, S: ClassTag, T: ClassTag]( + pairRDD: RDD[(K, S)], + partitioner: Partitioner, + updateTime: Long): RDD[TrackStateRDDRecord[K, S, T]] = { + + val createRecord = (iterator: Iterator[(K, S)]) => { + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + } + pairRDD.partitionBy(partitioner).mapPartitions[TrackStateRDDRecord[K, S, T]]( + createRecord, true) + } +} + + +// ----------------------------------------------- +// ---------------- SessionDStream --------------- +// ----------------------------------------------- + + +private[streaming] class TrackeStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) + extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val trackingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + override val mustCheckpoint = true + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = { + val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { + TrackStateRDD.createFromPairRDD[K, S, T]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime.milliseconds + ) + } + val newDataRDD = parent.getOrCompute(validTime).get + val partitionedDataRDD = newDataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + + Some(new TrackStateRDD( + ssc.sparkContext, prevStateRDD, partitionedDataRDD, + trackingFunction, validTime.milliseconds, timeoutThresholdTime)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala new file mode 100644 index 0000000000000..d19ed1b31eabc --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ +import org.apache.spark.util.collection.OpenHashMap + +/** Internal interface for defining the map that keeps track of sessions. */ +private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { + + /** Get the state for a key if it exists */ + def get(key: K): Option[S] + + /** Get all the keys and states whose updated time is older than the given threshold time */ + def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] + + /** Get all the keys and states in this map. */ + def getAll(): Iterator[(K, S, Long)] + + /** Add or update state */ + def put(key: K, state: S, updatedTime: Long): Unit + + /** Remove a key */ + def remove(key: K): Unit + + /** + * Shallow copy `this` map to create a new state map. + * Updates to the new map should not mutate `this` map. + */ + def copy(): StateMap[K, S] + + def toDebugString(): String = toString() +} + +/** Companion object for [[StateMap]], with utility methods */ +private[streaming] object StateMap { + def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + + def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { + val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", + DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + } +} + +/** Specific implementation of SessionStore interface representing an empty map */ +private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { + override def put(key: K, session: S, updateTime: Long): Unit = ??? 
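+  // EmptyStateMap sits at the bottom of every delta chain: it stores nothing, so the getters
+  // below return empty results and remove() is a no-op. put() above is left unimplemented
+  // (`???`) and throws scala.NotImplementedError if it is ever invoked.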
+ override def get(key: K): Option[S] = None + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def copy(): StateMap[K, S] = new EmptyStateMap[K, S] + override def remove(key: K): Unit = { } + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty + override def toDebugString(): String = "" +} + + + +/** Implementation of StateMap based on Spark's OpenHashMap */ +private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( + @transient @volatile private var parentStateMap: StateMap[K, S], + initialCapacity: Int = 64, + deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + ) extends StateMap[K, S] { self => + + def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + new EmptyStateMap[K, S], + initialCapacity = initialCapacity, + deltaChainThreshold = deltaChainThreshold) + + def this(deltaChainThreshold: Int) = this( + initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + + def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + + @transient @volatile private var deltaMap = + new OpenHashMap[K, StateInfo[S]](initialCapacity) + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + val stateInfo = deltaMap(key) + if (stateInfo != null && !stateInfo.deleted) { + Some(stateInfo.data) + } else { + parentStateMap.get(key) + } + } + + /** Get all the keys and states whose updated time is older than the give threshold time */ + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { + val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.flatMap { case (key, stateInfo) => + if (! stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime) { + Some((key, stateInfo.data, stateInfo.updateTime)) + } else None + } + oldStates ++ updatedStates + } + + /** Get all the keys and states in this map. */ + override def getAll(): Iterator[(K, S, Long)] = { + + val oldStates = parentStateMap.getAll().filter { case (key, _, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Add or update state */ + override def put(key: K, state: S, updateTime: Long): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.update(state, updateTime) + } else { + deltaMap.update(key, new StateInfo(state, updateTime)) + } + } + + /** Remove a state */ + override def remove(key: K): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.markDeleted() + } else { + val newInfo = new StateInfo[S](deleted = true) + deltaMap.update(key, newInfo) + } + } + + /** + * Shallow copy the map to create a new session store. Updates to the new map + * should not mutate `this` map. 
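+   * For example (an illustrative sketch; keys, values, and update times are arbitrary):
+   * {{{
+   *    val parent = new OpenHashMapBasedStateMap[Int, Int]()
+   *    parent.put(1, 100, 1)
+   *    val child = parent.copy()   // layers a fresh delta map on top of `parent`
+   *    child.put(2, 200, 2)
+   *    child.remove(1)             // recorded as a delete in the delta; `parent` is unchanged
+   *    assert(parent.getAll().toSet == Set((1, 100, 1)))
+   *    assert(child.getAll().toSet == Set((2, 200, 2)))
+   * }}}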
+ */ + override def copy(): StateMap[K, S] = { + new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) + } + + def shouldCompact: Boolean = { + deltaChainLength >= deltaChainThreshold + } + + def deltaChainLength: Int = parentStateMap match { + case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 + case _ => 0 + } + + def approxSize: Int = deltaMap.size + { + parentStateMap match { + case s: OpenHashMapBasedStateMap[_, _] => s.approxSize + case _ => 0 + } + } + + override def toDebugString(): String = { + val tabs = if (deltaChainLength > 0) { + (" " * (deltaChainLength - 1)) +"+--- " + } else "" + parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") + } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + + outputStream.defaultWriteObject() + + // Write the deltaMap + outputStream.writeInt(deltaMap.size) + val deltaMapIterator = deltaMap.iterator + var deltaMapCount = 0 + while (deltaMapIterator.hasNext) { + deltaMapCount += 1 + val (key, stateInfo) = deltaMapIterator.next() + outputStream.writeObject(key) + outputStream.writeObject(stateInfo) + } + assert(deltaMapCount == deltaMap.size) + + // Write the parentStateMap while consolidating + val doCompaction = shouldCompact + val newParentSessionStore = if (doCompaction) { + val initCapacity = if (approxSize > 0) approxSize else 64 + new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold) + } else { null } + + val iterOfActiveSessions = parentStateMap.getAll() + + var parentSessionCount = 0 + + outputStream.writeInt(approxSize) + + while(iterOfActiveSessions.hasNext) { + parentSessionCount += 1 + + val (key, state, updateTime) = iterOfActiveSessions.next() + outputStream.writeObject(key) + outputStream.writeObject(state) + outputStream.writeLong(updateTime) + + if (doCompaction) { + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + val limiterObj = new LimitMarker(parentSessionCount) + outputStream.writeObject(limiterObj) + if (doCompaction) { + parentStateMap = newParentSessionStore + } + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + inputStream.defaultReadObject() + + val deltaMapSize = inputStream.readInt() + deltaMap = new OpenHashMap[K, StateInfo[S]]() + var deltaMapCount = 0 + while (deltaMapCount < deltaMapSize) { + val key = inputStream.readObject().asInstanceOf[K] + val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]] + deltaMap.update(key, sessionInfo) + deltaMapCount += 1 + } + + val parentSessionStoreSizeHint = inputStream.readInt() + val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( + initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + + var parentSessionLoopDone = false + while(!parentSessionLoopDone) { + val obj = inputStream.readObject() + if (obj.isInstanceOf[LimitMarker]) { + parentSessionLoopDone = true + val expectedCount = obj.asInstanceOf[LimitMarker].num + assert(expectedCount == newParentSessionStore.deltaMap.size) + } else { + val key = obj.asInstanceOf[K] + val state = inputStream.readObject().asInstanceOf[S] + val updateTime = inputStream.readLong() + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + parentStateMap = newParentSessionStore + } +} + +private[streaming] object OpenHashMapBasedStateMap { + + case class StateInfo[S]( + var data: S = null.asInstanceOf[S], + var updateTime: Long = -1, + var 
deleted: Boolean = false) { + + def markDeleted(): Unit = { + deleted = true + } + + def update(newData: S, newUpdateTime: Long): Unit = { + data = newData + updateTime = newUpdateTime + deleted = false + } + } + + class LimitMarker(val num: Int) extends Serializable + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 +} \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 255376807c957..310cf2d467006 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -22,13 +22,11 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} +import org.apache.spark.streaming.dstream.{TrackStateSpec, DStream, WindowedDStream} import org.apache.spark.util.{Clock, ManualClock} -import org.apache.spark.HashPartitioner +import org.apache.spark.{HashPartitioner, SparkConf, SparkException} class BasicOperationsSuite extends TestSuiteBase { test("map") { @@ -631,6 +629,42 @@ class BasicOperationsSuite extends TestSuiteBase { } } + + test("trackStateByKey with emitted states") { + val inputData = + Seq( + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ) + + val trackStateOp = (s: DStream[String]) => { + val updateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) + } + s.map(x => (x, 1)).trackStateByKey(TrackStateSpec(updateFunc)) + } + + testOperation(inputData, trackStateOp, outputData, true) + } + + /** Test cleanup of RDDs in DStream metadata */ def runCleanupTest[T: ClassTag]( conf2: SparkConf, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala new file mode 100644 index 0000000000000..bd1936e9ad9dc --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.util.{OpenHashMapBasedStateMap, StateMap} +import org.apache.spark.util.Utils + +class StateMapSuite extends SparkFunSuite { + + test("OpenHashMapBasedStateMap - basic operations") { + val map = new OpenHashMapBasedStateMap[Int, Int]() + + map.put(1, 100, 10) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + map.put(2, 200, 20) + assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20))) + + map.remove(1) + assert(map.get(1) === None) + assert(map.getAll().toSet === Set((2, 200, 20))) + } + + test("OpenHashMapBasedStateMap - basic operations after copy") { + val parentMap = new OpenHashMapBasedStateMap[Int, Int]() + parentMap.put(1, 100, 1) + parentMap.put(2, 200, 2) + parentMap.remove(1) + + val map = parentMap.copy() + assert(map.getAll().toSet === Set((2, 200, 2))) + + // Add new items + map.put(3, 300, 3) + map.put(4, 400, 4) + assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Remove items + map.remove(4) // remove item added to this map + map.remove(2) // remove item remove in parent map + assert(map.getAll().toSet === Set((3, 300, 3))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Update items + map.put(1, 1000, 100) // update item removed in parent map + map.put(2, 2000, 200) // update item added in parent map and removed in this map + map.put(3, 3000, 300) // update item added in this map + map.put(4, 4000, 400) // update item removed in this map + + assert(map.getAll().toSet === + Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + } + + test("OpenHashMapBasedStateMap - serializing and deserializing") { + val map1 = new OpenHashMapBasedStateMap[Int, Int]() + map1.put(1, 100, 1) + map1.put(2, 200, 2) + + val map2 = map1.copy() + map2.put(3, 300, 3) + map2.put(4, 400, 4) + + val map3 = map2.copy() + map3.put(3, 600, 3) + map3.remove(2) + + // Do not test compaction + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + + val map3_ = Utils.deserialize[StateMap[Int, Int]](Utils.serialize(map3), Thread.currentThread().getContextClassLoader) + assert(map3_.getAll().toSet === map3.getAll().toSet) + assert(map3.getAll().forall { case (key, state, _) => map3_.get(key) === Some(state)}) + } + + test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { + val targetDeltaLength = 10 + val deltaChainThreshold = 5 + + var map = new OpenHashMapBasedStateMap[Int, Int]( + deltaChainThreshold = deltaChainThreshold) + + for(i <- 1 to targetDeltaLength) { + map.put(Random.nextInt(), Random.nextInt(), Random.nextLong()) + map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + } + assert(map.deltaChainLength > deltaChainThreshold) + assert(map.shouldCompact === true) + + val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assert(deser_map.deltaChainLength < deltaChainThreshold) + assert(deser_map.shouldCompact === false) + assert(deser_map.getAll().toSet === map.getAll().toSet) + assert(map.getAll().forall { case (key, state, _) => deser_map.get(key) === Some(state)}) + + } +} \ No newline at end of file
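
For reference, a minimal end-to-end sketch of the API introduced by this patch, combining TrackStateSpec's builder options (partitioning and a timeout) with the State handling shown above. This is an illustrative, untested sketch: the socket source, checkpoint directory, parallelism, and timeout value are placeholders, and the timeout branch follows TrackStateRDD's full-scan path (the tracking function is invoked with no new value and `isTimingOut()` returning true).

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StreamingContext, TrackStateSpec}

object TrackStateByKeySketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(new SparkConf().setAppName("TrackStateByKeySketch"), Seconds(1))
    ssc.checkpoint(".")  // required: the track-state DStream sets mustCheckpoint = true

    // Placeholder source: (word, 1) pairs from a socket text stream.
    val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" ")).map(w => (w, 1))

    // Called with Some(count) for new data; called with None and isTimingOut() == true
    // when an idle key is being dropped during a full scan.
    val trackingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      if (state.isTimingOut()) {
        None  // the state is being removed by the system; a timing-out state cannot be updated
      } else {
        val sum = one.getOrElse(0) + state.getOrElse(0)
        state.update(sum)
        Some((word, sum))
      }
    }

    val spec = TrackStateSpec(trackingFunc)  // Scala; Java callers use TrackStateSpec.create(...)
      .numPartitions(2)                      // or .partitioner(...) / .initialState(...)
      .timeout(Seconds(30))                  // keys idle longer than this are eventually timed out

    val runningCounts = words.trackStateByKey(spec)  // DStream[(String, Int)] of emitted records
    runningCounts.print()

    ssc.start()
    ssc.awaitTermination()
  }
}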