Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ case class FlatMapGroupsWithStateExec(
outputIterator,
{
store.commit()
longMetric("numTotalStateRows") += store.numKeys()
setStoreMetrics(store)
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,10 @@ trait ProgressReporter extends Logging {
if (lastExecution == null) return Nil
// lastExecution could belong to one of the previous triggers if `!hasNewData`.
// Walking the plan again should be inexpensive.
val stateNodes = lastExecution.executedPlan.collect {
case p if p.isInstanceOf[StateStoreWriter] => p
}
stateNodes.map { node =>
val numRowsUpdated = if (hasNewData) {
node.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L)
} else {
0L
}
new StateOperatorProgress(
numRowsTotal = node.metrics.get("numTotalStateRows").map(_.value).getOrElse(0L),
numRowsUpdated = numRowsUpdated)
lastExecution.executedPlan.collect {
case p if p.isInstanceOf[StateStoreWriter] =>
val progress = p.asInstanceOf[StateStoreWriter].getProgress()
if (hasNewData) progress else progress.copy(newNumRowsUpdated = 0)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.io.LZ4CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
import org.apache.spark.util.{SizeEstimator, Utils}


/**
Expand Down Expand Up @@ -172,7 +172,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
}
}

override def numKeys(): Long = mapToUpdate.size()
override def metrics: StateStoreMetrics = {
StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a flag for this? SizeEstimator.estimate will be very slow when there are a lot of states, because it scans all objects using reflection.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me do some tests to understand how long it will take. For arrays it will just sample, so it should not take that long.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Takes < 0.5 ms with a state store with 5 million elements

}

/**
* Whether all updates have been committed
Expand Down Expand Up @@ -230,6 +232,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
loadedMaps.values.foreach(_.clear())
}

/** This provider exposes no implementation-specific metrics beyond the standard ones. */
override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil

override def toString(): String = {
s"HDFSStateStoreProvider[" +
s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,33 @@ trait StateStore {

def iterator(): Iterator[UnsafeRowPair]

/** Number of keys in the state store */
def numKeys(): Long
/** Current metrics of the state store */
def metrics: StateStoreMetrics

/**
* Whether all updates have been committed
*/
def hasCommitted: Boolean
}

/**
 * Metrics reported by a state store.
 * @param numKeys Number of keys in the state store
 * @param memoryUsedBytes Memory used by the state store, in bytes
 * @param customMetrics Custom implementation-specific metrics. Each key's `name` must match
 *                      one of the metrics declared by `StateStoreProvider.supportedCustomMetrics`,
 *                      so that the reported values can be mapped back to registered SQL metrics.
 */
case class StateStoreMetrics(
numKeys: Long,
memoryUsedBytes: Long,
customMetrics: Map[StateStoreCustomMetric, Long])

/**
 * Name and description of a custom implementation-specific metric that a
 * state store may wish to expose. `name` is used as the metric lookup key;
 * `desc` is the human-readable description attached to the SQL metric.
 */
case class StateStoreCustomMetric(name: String, desc: String)

/**
* Trait representing a provider that provide [[StateStore]] instances representing
Expand Down Expand Up @@ -158,22 +176,36 @@ trait StateStoreProvider {

/** Optional method for providers to allow for background maintenance (e.g. compactions) */
def doMaintenance(): Unit = { }

/**
* Optional custom metrics that the implementation may want to report.
* @note The StateStore objects created by this provider must report the same custom metrics
* (specifically, same names) through `StateStore.metrics`.
*/
def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil
}

object StateStoreProvider {

/**
 * Return an instance of the given provider class name. The instance will not be initialized.
 *
 * @param providerClassName fully-qualified class name of a `StateStoreProvider` implementation
 *                          with a public no-arg constructor
 */
def create(providerClassName: String): StateStoreProvider = {
val providerClass = Utils.classForName(providerClassName)
// Invoke the no-arg constructor explicitly: Class.newInstance is deprecated because it
// propagates checked exceptions thrown by the constructor without wrapping them.
providerClass.getDeclaredConstructor().newInstance().asInstanceOf[StateStoreProvider]
}

/**
* Return a provider instance of the given provider class.
* The instance will be already initialized.
* Return an instance of the required provider, initialized with the given configurations.
*/
def instantiate(
def createAndInit(
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
indexOrdinal: Option[Int], // for sorting the data
storeConf: StateStoreConf,
hadoopConf: Configuration): StateStoreProvider = {
val providerClass = Utils.classForName(storeConf.providerClass)
val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider]
val provider = create(storeConf.providerClass)
provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
provider
}
Expand Down Expand Up @@ -298,7 +330,7 @@ object StateStore extends Logging {
startMaintenanceIfNeeded()
val provider = loadedProviders.getOrElseUpdate(
storeProviderId,
StateStoreProvider.instantiate(
StateStoreProvider.createAndInit(
storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
)
reportActiveStoreInstance(storeProviderId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming
import java.util.UUID
import java.util.concurrent.TimeUnit._

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
Expand All @@ -29,9 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
import org.apache.spark.util.{CompletionIterator, NextIterator}

Expand Down Expand Up @@ -73,8 +75,21 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
"numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"),
"allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"),
"allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"),
"commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes")
)
"commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes"),
"stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state")
) ++ stateStoreCustomMetrics

/**
 * Build a [[StateOperatorProgress]] snapshot for this stateful operator.
 * Must be called on the driver after this SparkPlan has executed and its
 * SQL metrics have been collected.
 */
def getProgress(): StateOperatorProgress = {
val rowsTotal = longMetric("numTotalStateRows").value
val rowsUpdated = longMetric("numUpdatedStateRows").value
val stateMemoryBytes = longMetric("stateMemory").value
val partitions = this.sqlContext.conf.numShufflePartitions
new StateOperatorProgress(
numRowsTotal = rowsTotal,
numRowsUpdated = rowsUpdated,
memoryUsedBytes = stateMemoryBytes,
numPartitions = partitions)
}

/** Records the duration of running `body` for the next query progress update. */
protected def timeTakenMs(body: => Unit): Long = {
Expand All @@ -83,6 +98,26 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
val endTime = System.nanoTime()
math.max(NANOSECONDS.toMillis(endTime - startTime), 0)
}

/**
 * Record the state store's reported metrics into this operator's SQL metrics.
 * Should be called from the task after all updates to the store have been made.
 */
protected def setStoreMetrics(store: StateStore): Unit = {
val reported = store.metrics
longMetric("numTotalStateRows") += reported.numKeys
longMetric("stateMemory") += reported.memoryUsedBytes
// Fold in any implementation-specific metrics the store reports; each metric name
// must have been registered via `stateStoreCustomMetrics`.
for ((customMetric, value) <- reported.customMetrics) {
longMetric(customMetric.name) += value
}
}

/**
 * SQL metrics declared by the configured state store provider, keyed by metric name.
 * A throwaway (uninitialized) provider instance is created solely to query the
 * custom metrics it supports.
 */
private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass)
provider.supportedCustomMetrics
.map(metric => metric.name -> SQLMetrics.createTimingMetric(sparkContext, metric.desc))
.toMap
}
}

/** An operator that supports watermark. */
Expand Down Expand Up @@ -197,7 +232,6 @@ case class StateStoreSaveExec(
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
val numOutputRows = longMetric("numOutputRows")
val numTotalStateRows = longMetric("numTotalStateRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
Expand All @@ -218,7 +252,7 @@ case class StateStoreSaveExec(
commitTimeMs += timeTakenMs {
store.commit()
}
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
store.iterator().map { rowPair =>
numOutputRows += 1
rowPair.value
Expand Down Expand Up @@ -261,7 +295,7 @@ case class StateStoreSaveExec(
override protected def close(): Unit = {
allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
commitTimeMs += timeTakenMs { store.commit() }
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
}
}

Expand All @@ -285,7 +319,7 @@ case class StateStoreSaveExec(
// Remove old aggregates if watermark specified
allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
commitTimeMs += timeTakenMs { store.commit() }
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
false
} else {
true
Expand Down Expand Up @@ -368,7 +402,7 @@ case class StreamingDeduplicateExec(
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
commitTimeMs += timeTakenMs { store.commit() }
numTotalStateRows += store.numKeys()
setStoreMetrics(store)
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,25 @@ import org.apache.spark.annotation.InterfaceStability
@InterfaceStability.Evolving
class StateOperatorProgress private[sql](
val numRowsTotal: Long,
val numRowsUpdated: Long) extends Serializable {
val numRowsUpdated: Long,
val memoryUsedBytes: Long,
val numPartitions: Long
) extends Serializable {

/** The compact JSON representation of this progress. */
def json: String = compact(render(jsonValue))

/** The pretty (i.e. indented) JSON representation of this progress. */
def prettyJson: String = pretty(render(jsonValue))

/** Copy with `numRowsUpdated` replaced by `newNumRowsUpdated`; all other fields preserved. */
private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress =
new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, numPartitions)

private[sql] def jsonValue: JValue = {
("numRowsTotal" -> JInt(numRowsTotal)) ~
("numRowsUpdated" -> JInt(numRowsUpdated))
("numRowsUpdated" -> JInt(numRowsUpdated)) ~
("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~
("numPartitions" -> JInt(numPartitions))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
import java.io.{File, IOException}
import java.net.URI
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.collection.mutable
Expand Down Expand Up @@ -184,6 +185,15 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
}
}

test("reports memory usage") {
// Reported memory usage should grow once at least one key has been committed.
val provider = newStoreProvider()
val store = provider.getStore(0)
val emptyStoreMemory = store.metrics.memoryUsedBytes
put(store, "a", 1)
store.commit()
assert(store.metrics.memoryUsedBytes > emptyStoreMemory)
}

test("StateStore.get") {
quietly {
val dir = newDir()
Expand Down Expand Up @@ -554,22 +564,22 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
assert(!store.hasCommitted)
assert(get(store, "a") === None)
assert(store.iterator().isEmpty)
assert(store.numKeys() === 0)
assert(store.metrics.numKeys === 0)

// Verify state after updating
put(store, "a", 1)
assert(get(store, "a") === Some(1))
assert(store.numKeys() === 1)
assert(store.metrics.numKeys === 1)

assert(store.iterator().nonEmpty)
assert(getLatestData(provider).isEmpty)

// Make updates, commit and then verify state
put(store, "b", 2)
put(store, "aa", 3)
assert(store.numKeys() === 3)
assert(store.metrics.numKeys === 3)
remove(store, _.startsWith("a"))
assert(store.numKeys() === 1)
assert(store.metrics.numKeys === 1)
assert(store.commit() === 1)

assert(store.hasCommitted)
Expand All @@ -587,9 +597,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
// New updates to the reloaded store with new version, and does not change old version
val reloadedProvider = newStoreProvider(store.id)
val reloadedStore = reloadedProvider.getStore(1)
assert(reloadedStore.numKeys() === 1)
assert(reloadedStore.metrics.numKeys === 1)
put(reloadedStore, "c", 4)
assert(reloadedStore.numKeys() === 2)
assert(reloadedStore.metrics.numKeys === 2)
assert(reloadedStore.commit() === 2)
assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types.{DataType, IntegerType}
Expand Down Expand Up @@ -1077,7 +1077,7 @@ object FlatMapGroupsWithStateSuite {
override def abort(): Unit = { }
override def id: StateStoreId = null
override def version: Long = 0
override def numKeys(): Long = map.size
override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty)
override def hasCommitted: Boolean = true
}
}
Loading