@@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider(
trait STATE
case object UPDATING extends STATE
case object COMMITTED extends STATE
case object CANCELLED extends STATE
case object ABORTED extends STATE

private val newVersion = version + 1
private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
@@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider(

override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id

/**
* Update the value of a key using the value generated by the update function.
* @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
* versions of the store data.
*/
override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
verify(state == UPDATING, "Cannot update after already committed or cancelled")
val oldValueOption = Option(mapToUpdate.get(key))
val value = updateFunc(oldValueOption)
override def get(key: UnsafeRow): Option[UnsafeRow] = {
Option(mapToUpdate.get(key))
}

override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
verify(state == UPDATING, "Cannot remove after already committed or cancelled")
Member: nit: remove -> put

val isNewKey = !mapToUpdate.containsKey(key)
mapToUpdate.put(key, value)

Option(allUpdates.get(key)) match {
@@ -115,8 +114,7 @@
case None =>
// There was no prior update, so mark this as added or updated according to its presence
// in previous version.
val update =
if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
allUpdates.put(key, update)
}
writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
@@ -148,7 +146,7 @@

/** Commit all the updates that have been made to the store, and return the new version. */
override def commit(): Long = {
verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
verify(state == UPDATING, "Cannot commit after already committed or cancelled")
Contributor Author: nit: cancelled -> aborted

try {
finalizeDeltaFile(tempDeltaFileStream)
@@ -164,8 +162,8 @@
}

/** Abort all the updates made on this store. This store will not be usable any more. */
override def cancel(): Unit = {
state = CANCELLED
override def abort(): Unit = {
state = ABORTED
if (tempDeltaFileStream != null) {
tempDeltaFileStream.close()
}
@@ -176,8 +174,8 @@
}

/**
* Get an iterator of all the store data. This can be called only after committing the
* updates.
* Get an iterator of all the store data.
* This can be called only after committing all the updates made in the current thread.
*/
override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
verify(state == COMMITTED, "Cannot get iterator of store data before committing")
@@ -186,7 +184,7 @@

/**
* Get an iterator of all the updates made to the store in the current version.
* This can be called only after committing the updates.
* This can be called only after committing all the updates made in the current thread.
*/
override def updates(): Iterator[StoreUpdate] = {
verify(state == COMMITTED, "Cannot get iterator of updates before committing")
@@ -196,7 +194,7 @@
/**
* Whether all updates have been committed
*/
override def hasCommitted: Boolean = {
override private[state] def hasCommitted: Boolean = {
state == COMMITTED
}
}
@@ -47,12 +47,11 @@ trait StateStore {
/** Version of the data in this store before committing updates. */
def version: Long

/**
* Update the value of a key using the value generated by the update function.
* @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
* versions of the store data.
*/
def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
/** Get the current value of a key. */
def get(key: UnsafeRow): Option[UnsafeRow]

/** Put a new value for a key. */
def put(key: UnsafeRow, value: UnsafeRow)

/**
* Remove keys that match the following condition.
@@ -65,24 +64,24 @@
def commit(): Long

/** Abort all the updates that have been made to the store. */
def cancel(): Unit
def abort(): Unit

/**
* Iterator of store data after a set of updates have been committed.
* This can be called only after commitUpdates() has been called in the current thread.
* This can be called only after committing all the updates made in the current thread.
*/
def iterator(): Iterator[(UnsafeRow, UnsafeRow)]

/**
* Iterator of the updates that have been committed.
* This can be called only after commitUpdates() has been called in the current thread.
* This can be called only after committing all the updates made in the current thread.
*/
def updates(): Iterator[StoreUpdate]

/**
* Whether all updates have been committed
*/
def hasCommitted: Boolean
private[state] def hasCommitted: Boolean
}
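
Not part of the diff: a minimal sketch of how the reworked get/put API is intended to be driven, mirroring the increment helper in the updated test suite further down. stringToRow, intToRow and rowToInt are that suite's row-conversion helpers, assumed here only for brevity.

// Illustration only: increment a per-key count with get/put, then commit.
// stringToRow, intToRow and rowToInt are the test suite's row-conversion helpers.
def incrementAll(store: StateStore, words: Iterator[String]): Long = {
  words.foreach { s =>
    val key = stringToRow(s)
    val oldValue = store.get(key).map(rowToInt).getOrElse(0) // None when the key is absent
    store.put(key, intToRow(oldValue + 1))
  }
  store.commit() // returns the new version; abort() would discard the updates instead
}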


@@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](

override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
var store: StateStore = null

Utils.tryWithSafeFinally {
val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
store = StateStore.get(
storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
val outputIter = storeUpdateFunction(store, inputIter)
assert(store.hasCommitted)
outputIter
} {
if (store != null) store.cancel()
}
val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
store = StateStore.get(
storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
storeUpdateFunction(store, inputIter)
Contributor Author: The finally was removed to allow commits to be done lazily.

}
}
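
Not part of the diff: to make the lazy-commit remark above concrete, here is a minimal sketch of the pattern the new test below relies on, wrapping the output in a CompletionIterator so commit() runs only after the downstream task has drained the iterator. UnsafeRow and CompletionIterator are the Spark classes already used in this PR; everything else is illustrative.

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.util.CompletionIterator

// Illustrative storeUpdateFunction: apply puts while the iterator is consumed,
// and defer commit() to the iterator's completion callback instead of a finally.
def lazilyCommittedPuts(
    store: StateStore,
    input: Iterator[(UnsafeRow, UnsafeRow)]): Iterator[(UnsafeRow, UnsafeRow)] = {
  val written = input.map { case (key, value) =>
    store.put(key, value)
    (key, value)
  }
  // commit() fires once the downstream task has exhausted the output,
  // not eagerly inside StateStoreRDD.compute().
  CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
    written, store.commit())
}

The "usage with iterators - only gets and only puts" test below exercises this same shape with string keys.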
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
import java.io.File
import java.nio.file.Files

import scala.util.Random

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
@@ -33,7 +34,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.Utils
import org.apache.spark.util.{CompletionIterator, Utils}

class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {

@@ -54,62 +55,91 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}

test("versioning and immutability") {
quietly {
withSpark(new SparkContext(sparkConf)) { sc =>
implicit val sqlContet = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
store.update(
stringToRow(s), oldRow => {
val oldValue = oldRow.map(rowToInt).getOrElse(0)
intToRow(oldValue + 1)
})
}
store.commit()
store.iterator().map(rowsToStringInt)
}
val opId = 0
val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 0, keySchema, valueSchema)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
withSpark(new SparkContext(sparkConf)) { sc =>
implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 0, keySchema, valueSchema)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))

// Generate next version of stores
val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 1, keySchema, valueSchema)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Generate next version of stores
val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 1, keySchema, valueSchema)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))

// Make sure the previous RDD still has the same data.
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
}
// Make sure the previous RDD still has the same data.
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
}
}

test("recovering from files") {
quietly {
val opId = 0
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString

def makeStoreRDD(
sc: SparkContext,
seq: Seq[String],
storeVersion: Int): RDD[(String, Int)] = {
implicit val sqlContext = new SQLContext(sc)
makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion, keySchema, valueSchema)
}

def makeStoreRDD(
sc: SparkContext,
seq: Seq[String],
storeVersion: Int): RDD[(String, Int)] = {
implicit val sqlContext = new SQLContext(sc)
makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion, keySchema, valueSchema)
// Generate RDDs and state store data
withSpark(new SparkContext(sparkConf)) { sc =>
for (i <- 1 to 20) {
require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
}
}

// Generate RDDs and state store data
withSpark(new SparkContext(sparkConf)) { sc =>
for (i <- 1 to 20) {
require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
// With a new context, try using the earlier state store data
withSpark(new SparkContext(sparkConf)) { sc =>
assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
}
}

test("usage with iterators - only gets and only puts") {
withSpark(new SparkContext(sparkConf)) { sc =>
implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0

// Returns an iterator of the incremented values that were put into the store
def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
val resIterator = iter.map { s =>
val key = stringToRow(s)
val oldValue = store.get(key).map(rowToInt).getOrElse(0)
val newValue = oldValue + 1
store.put(key, intToRow(newValue))
(s, newValue)
}
CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, {
store.commit()
})
}

// With a new context, try using the earlier state store data
withSpark(new SparkContext(sparkConf)) { sc =>
assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
def iteratorOfGets(
store: StateStore,
iter: Iterator[String]): Iterator[(String, Option[Int])] = {
iter.map { s =>
val key = stringToRow(s)
val value = store.get(key).map(rowToInt)
(s, value)
}
}

val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionWithStateStore(
iteratorOfGets, path, opId, storeVersion = 0, keySchema, valueSchema)
assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))

val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
iteratorOfPuts, path, opId, storeVersion = 0, keySchema, valueSchema)
assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))

val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionWithStateStore(
iteratorOfGets, path, opId, storeVersion = 1, keySchema, valueSchema)
assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
}
}

@@ -152,17 +182,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
store.update(
stringToRow(s), oldRow => {
val oldValue = oldRow.map(rowToInt).getOrElse(0)
intToRow(oldValue + 1)
})
}
store.commit()
store.iterator().map(rowsToStringInt)
}
val opId = 0
val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 0, keySchema, valueSchema)
@@ -185,11 +204,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn

private val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
store.update(
stringToRow(s), oldRow => {
val oldValue = oldRow.map(rowToInt).getOrElse(0)
intToRow(oldValue + 1)
})
val key = stringToRow(s)
val oldValue = store.get(key).map(rowToInt).getOrElse(0)
store.put(key, intToRow(oldValue + 1))
}
store.commit()
store.iterator().map(rowsToStringInt)