@@ -18,12 +18,13 @@
package org.apache.spark.examples.streaming

import org.apache.spark.SparkConf
import org.apache.spark.HashPartitioner
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._

/**
* Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
* second.
* second, starting with an initial value for the word counts.
* Usage: StatefulNetworkWordCount <hostname> <port>
* <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
* data.
@@ -51,11 +52,18 @@ object StatefulNetworkWordCount {
Some(currentCount + previousCount)
}

val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}

val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")

// Initial RDD input to updateStateByKey
val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

// Create a NetworkInputDStream on target ip:port and count the
// words in the input stream of \n delimited text (e.g. generated by 'nc')
val lines = ssc.socketTextStream(args(0), args(1).toInt)
@@ -64,7 +72,8 @@ object StatefulNetworkWordCount {

// Update the cumulative count using updateStateByKey
// This will give a DStream made of state (which is the cumulative count of the words)
val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD)
stateDstream.print()
ssc.start()
ssc.awaitTermination()
@@ -492,6 +492,25 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner)
}

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* org.apache.spark.Partitioner is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then the
* corresponding state key-value pair will be eliminated.
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new
* DStream.
* @param initialRDD initial state value of each key.
* @tparam S State type
*/
def updateStateByKey[S](
Contributor: I guess this makes it different from the Scala API, and the reason is that the Java API does not expose the version of the API that uses iterators. As far as I remember, that was not done because it's tricky to write a Java function that deals with Scala iterators.

Let's try to make the Java and Scala APIs consistent (minus the iterator-based ones). So could you please add a Scala API equivalent to this, that is, (simple non-iterator function, partitioner, initial RDD)? Also a unit test for that, please :)

updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
partitioner: Partitioner,
initialRDD: JavaPairRDD[K, S]
): JavaPairDStream[K, S] = {
Contributor: two-space indent here.

implicit val cm: ClassTag[S] = fakeClassTag
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner, initialRDD)
}
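For reference, a minimal sketch (not part of this patch) of the Scala-side overload the review comment above asks for, i.e. a simple non-iterator function plus partitioner plus initial RDD; the PairDStreamFunctions change later in this diff adds essentially this signature:

    def updateStateByKey[S: ClassTag](
        updateFunc: (Seq[V], Option[S]) => Option[S],
        partitioner: Partitioner,
        initialRDD: RDD[(K, S)]
      ): DStream[(K, S)]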

/**
* Return a new DStream by applying a map function to the value of each key-value pairs in
@@ -413,7 +413,54 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
partitioner: Partitioner,
rememberPartitioner: Boolean
): DStream[(K, S)] = {
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* org.apache.spark.Partitioner is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then the
* corresponding state key-value pair will be eliminated.
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new
* DStream.
* @param initialRDD initial state value of each key.
* @tparam S State type
*/
def updateStateByKey[S: ClassTag](
updateFunc: (Seq[V], Option[S]) => Option[S],
partitioner: Partitioner,
initialRDD: RDD[(K, S)]
): DStream[(K, S)] = {
val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}
updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
}
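A minimal usage sketch of this overload, assuming the relevant imports (org.apache.spark.HashPartitioner), a StreamingContext `ssc` with checkpointing enabled, and a DStream of (word, 1) pairs named `wordDstream` (names are illustrative; the StatefulNetworkWordCount example above does the equivalent):

    val initialRDD = ssc.sparkContext.parallelize(Seq(("hello", 1), ("world", 1)))
    val updateFunc = (values: Seq[Int], state: Option[Int]) =>
      Some(values.sum + state.getOrElse(0))
    // Keys present in initialRDD start from their initial counts instead of zero
    val stateDstream = wordDstream.updateStateByKey[Int](
      updateFunc, new HashPartitioner(ssc.sparkContext.defaultParallelism), initialRDD)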

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* org.apache.spark.Partitioner is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then the
* corresponding state key-value pair will be eliminated. Note that
* this function may generate a tuple with a different key
* than the input key. It is up to the developer to decide whether to
* remember the partitioner despite the key being changed.
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new
* DStream
* @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs.
* @param initialRDD initial state value of each key.
* @tparam S State type
*/
def updateStateByKey[S: ClassTag](
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
rememberPartitioner: Boolean,
initialRDD: RDD[(K, S)]
): DStream[(K, S)] = {
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
rememberPartitioner, Some(initialRDD))
}
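To illustrate the scaladoc note about re-keying: a hypothetical iterator-based update function that emits a different key than its input, in which case it is usually not meaningful to remember the partitioner, since the output is no longer partitioned by the original key:

    // Hypothetical update function with K = String, V = Int, S = Int; the output key differs from the input key
    val rekeyingUpdateFunc = (it: Iterator[(String, Seq[Int], Option[Int])]) =>
      it.map { case (word, counts, state) =>
        (word.toUpperCase, counts.sum + state.getOrElse(0))
      }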

/**
@@ -30,7 +30,8 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
parent: DStream[(K, V)],
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
preservePartitioning: Boolean
preservePartitioning: Boolean,
initialRDD : Option[RDD[(K, S)]]
) extends DStream[(K, S)](parent.ssc) {

super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -41,6 +42,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](

override val mustCheckpoint = true

private [this] def computeUsingPreviousRDD (
parentRDD : RDD[(K, V)], prevStateRDD : RDD[(K, S)]) = {
// Define the function for the mapPartition operation on cogrouped RDD;
// first map the cogrouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
val i = iterator.map(t => {
val itr = t._2._2.iterator
val headOption = if(itr.hasNext) Some(itr.next) else None
(t._1, t._2._1.toSeq, headOption)
})
updateFuncLocal(i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
Some(stateRDD)
}
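A small standalone sketch (illustrative key/value types, assuming a SparkContext `sc`) of the cogroup shape this helper relies on: cogrouping the parent batch with the previous state RDD yields, per key, the batch values plus at most one previous state, which is exactly the (key, values, Option[state]) triple fed to the update function:

    val parentRDD = sc.parallelize(Seq(("a", 1), ("a", 1), ("b", 1)))  // current batch
    val stateRDD  = sc.parallelize(Seq(("a", 3)))                      // previous state
    val cogrouped = parentRDD.cogroup(stateRDD)  // RDD[(String, (Iterable[Int], Iterable[Int]))]
    cogrouped.collect().foreach { case (k, (values, states)) =>
      // prints roughly ("a", Seq(1, 1), Some(3)) and ("b", Seq(1), None)
      println((k, values.toSeq, states.headOption))
    }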

override def compute(validTime: Time): Option[RDD[(K, S)]] = {

// Try to get the previous state RDD
@@ -51,25 +71,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
case Some(parentRDD) => { // If parent RDD exists, then compute as usual

// Define the function for the mapPartition operation on cogrouped RDD;
// first map the cogrouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
val i = iterator.map(t => {
val itr = t._2._2.iterator
val headOption = itr.hasNext match {
case true => Some(itr.next())
case false => None
}
(t._1, t._2._1.toSeq, headOption)
})
updateFuncLocal(i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
Some(stateRDD)
computeUsingPreviousRDD (parentRDD, prevStateRDD)
}
case None => { // If parent RDD does not exist

@@ -90,19 +92,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
case Some(parentRDD) => { // If parent RDD exists, then compute as usual
initialRDD match {
case None => {
// Define the function for the mapPartition operation on grouped RDD;
// first map the grouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator : Iterator[(K, Iterable[V])]) => {
updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
}

// Define the function for the mapPartition operation on grouped RDD;
// first map the grouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
val groupedRDD = parentRDD.groupByKey (partitioner)
val sessionRDD = groupedRDD.mapPartitions (finalFunc, preservePartitioning)
// logDebug("Generating state RDD for time " + validTime + " (first)")
Some (sessionRDD)
}
case Some (initialStateRDD) => {
computeUsingPreviousRDD(parentRDD, initialStateRDD)
}
}

val groupedRDD = parentRDD.groupByKey(partitioner)
val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
// logDebug("Generating state RDD for time " + validTime + " (first)")
Some(sessionRDD)
}
case None => { // If parent RDD does not exist, then nothing to do!
// logDebug("Not generating state RDD (no previous state, no parent)")
@@ -806,15 +806,17 @@ public void testUnion() {
* Performs an order-invariant comparison of lists representing two RDD streams. This allows
* us to account for ordering variation within individual RDD's which occurs during windowing.
*/
public static <T extends Comparable<T>> void assertOrderInvariantEquals(
public static <T> void assertOrderInvariantEquals(
List<List<T>> expected, List<List<T>> actual) {
List<Set<T>> expectedSets = new ArrayList<Set<T>>();
for (List<T> list: expected) {
Collections.sort(list);
expectedSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
List<Set<T>> actualSets = new ArrayList<Set<T>>();
for (List<T> list: actual) {
Collections.sort(list);
actualSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
Assert.assertEquals(expected, actual);
Assert.assertEquals(expectedSets, actualSets);
}


@@ -1239,6 +1241,49 @@ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
Assert.assertEquals(expected, result);
}

@SuppressWarnings("unchecked")
@Test
public void testUpdateStateByKeyWithInitial() {
List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;

List<Tuple2<String, Integer>> initial = Arrays.asList (
new Tuple2<String, Integer> ("california", 1),
new Tuple2<String, Integer> ("new york", 2));

JavaRDD<Tuple2<String, Integer>> tmpRDD = ssc.sparkContext().parallelize(initial);
JavaPairRDD<String, Integer> initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD);

List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
Arrays.asList(new Tuple2<String, Integer>("california", 5),
new Tuple2<String, Integer>("new york", 7)),
Arrays.asList(new Tuple2<String, Integer>("california", 15),
new Tuple2<String, Integer>("new york", 11)),
Arrays.asList(new Tuple2<String, Integer>("california", 15),
new Tuple2<String, Integer>("new york", 11)));

JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);

JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
@Override
Contributor: Nit (do this if you have to make more changes): incorrect indentation; this should be indented more than the line with new Function2...

Author: I don't have any other changes to make. Actually, I copied the previous method.

public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
int out = 0;
if (state.isPresent()) {
out = out + state.get();
}
for (Integer v: values) {
out = out + v;
}
return Optional.of(out);
}
}, new HashPartitioner(1), initialRDD);
JavaTestUtils.attachTestOutputStream(updated);
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);

assertOrderInvariantEquals(expected, result);
}

@SuppressWarnings("unchecked")
@Test
public void testReduceByKeyAndWindowWithInverse() {
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
import org.apache.spark.HashPartitioner

class BasicOperationsSuite extends TestSuiteBase {
test("map") {
@@ -350,6 +351,79 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - simple with initial value RDD") {
val initial = Seq(("a", 1), ("c", 2))

val inputData =
Seq(
Seq("a"),
Seq("a", "b"),
Seq("a", "b", "c"),
Seq("a", "b"),
Seq("a"),
Seq()
)

val outputData =
Seq(
Seq(("a", 2), ("c", 2)),
Seq(("a", 3), ("b", 1), ("c", 2)),
Seq(("a", 4), ("b", 2), ("c", 3)),
Seq(("a", 5), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3))
)

val updateStateOperation = (s: DStream[String]) => {
val initialRDD = s.context.sparkContext.makeRDD(initial)
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
Some(values.sum + state.getOrElse(0))
}
s.map(x => (x, 1)).updateStateByKey[Int](updateFunc,
new HashPartitioner (numInputPartitions), initialRDD)
}

testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - with initial value RDD") {
val initial = Seq(("a", 1), ("c", 2))

val inputData =
Seq(
Seq("a"),
Seq("a", "b"),
Seq("a", "b", "c"),
Seq("a", "b"),
Seq("a"),
Seq()
)

val outputData =
Seq(
Seq(("a", 2), ("c", 2)),
Seq(("a", 3), ("b", 1), ("c", 2)),
Seq(("a", 4), ("b", 2), ("c", 3)),
Seq(("a", 5), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3))
)

val updateStateOperation = (s: DStream[String]) => {
val initialRDD = s.context.sparkContext.makeRDD(initial)
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
Some(values.sum + state.getOrElse(0))
}
val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}
s.map(x => (x, 1)).updateStateByKey[Int](newUpdateFunc,
new HashPartitioner (numInputPartitions), true, initialRDD)
}

testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - object lifecycle") {
val inputData =
Seq(