Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
Expand All @@ -33,59 +34,35 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}

/**
* Physical operator for executing `FlatMapGroupsWithState`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if we could update this doc too.

*
* @param func function called on each group
* @param keyDeserializer used to extract the key object for each group.
* @param valueDeserializer used to extract the items in the iterator from an input row.
* @param initialStateDeserializer used to extract the state object from the initialState dataset
* @param groupingAttributes used to group the data
* @param dataAttributes used to read the data
* @param outputObjAttr Defines the output object
* @param stateEncoder used to serialize/deserialize state before calling `func`
* @param outputMode the output mode of `func`
* @param timeoutConf used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermark event time watermark for the current batch
* @param initialState the user specified initial state
* @param hasInitialState indicates whether the initial state is provided or not
* @param child the physical plan for the underlying data
*/
case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
initialStateDeserializer: Expression,
groupingAttributes: Seq[Attribute],
initialStateGroupAttrs: Seq[Attribute],
dataAttributes: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
stateEncoder: ExpressionEncoder[Any],
stateFormatVersion: Int,
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
initialState: SparkPlan,
hasInitialState: Boolean,
child: SparkPlan
) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {

import FlatMapGroupsWithStateExecHelper._
trait FlatMapGroupsWithStateExecBase
extends StateStoreWriter with WatermarkSupport {
import GroupStateImpl._
import FlatMapGroupsWithStateExecHelper._

override def left: SparkPlan = child
protected val groupingAttributes: Seq[Attribute]

override def right: SparkPlan = initialState
protected val initialStateDeserializer: Expression
protected val initialStateGroupAttrs: Seq[Attribute]
protected val initialStateDataAttrs: Seq[Attribute]
protected val initialState: SparkPlan
protected val hasInitialState: Boolean

val stateInfo: Option[StatefulOperatorStateInfo]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would `stateInfo` and `eventTimeWatermark` still work if we changed them to `protected`?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember exposing them for a reason at the time. Let me change them to `protected` and see if it compiles.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing them to `protected` fails to compile with the following errors:

[error] /.../spark/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala:51:17: overriding method stateInfo in trait StatefulOperator of type => Option[org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo];
[error]  value stateInfo has weaker access privileges; it should be public
[error]   protected val stateInfo: Option[StatefulOperatorStateInfo]
[error]                 ^
[error] /.../spark/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala:57:17: overriding method eventTimeWatermark in trait WatermarkSupport of type => Option[Long];
[error]  value eventTimeWatermark has weaker access privileges; it should be public
[error]   protected val eventTimeWatermark: Option[Long]
[error]                 ^
[error] two errors found
[error] (sql / Compile / compileIncremental) Compilation failed
[error] Total time: 160 s (02:40), completed Sep 13, 2022 1:03:01 PM

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK, so that requirement comes from another trait. Thanks for the explanation.

protected val stateEncoder: ExpressionEncoder[Any]
protected val stateFormatVersion: Int
protected val outputMode: OutputMode
protected val timeoutConf: GroupStateTimeout
protected val batchTimestampMs: Option[Long]
val eventTimeWatermark: Option[Long]

private val isTimeoutEnabled = timeoutConf != NoTimeout
private val watermarkPresent = child.output.exists {
protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout
protected val watermarkPresent: Boolean = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
case _ => false
}

private[sql] val stateManager =
lazy val stateManager: StateManager =
createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)

/**
Expand Down Expand Up @@ -240,7 +217,7 @@ case class FlatMapGroupsWithStateExec(
stateManager.stateSchema,
numColsPrefixKey = 0,
stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
val processor = new InputProcessor(store)
val processor = createInputProcessor(store)
processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
}
} else {
Expand All @@ -252,31 +229,25 @@ case class FlatMapGroupsWithStateExec(
session.sqlContext.sessionState,
Some(session.sqlContext.streams.stateStoreCoordinator)
) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
val processor = new InputProcessor(store)
val processor = createInputProcessor(store)
processDataWithPartition(singleIterator, store, processor)
}
}
}

/** Helper class to update the state store */
class InputProcessor(store: StateStore) {
def createInputProcessor(store: StateStore): InputProcessor

// Converters for translating input keys, values, output data between rows and Java objects
private val getKeyObj =
ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
private val getValueObj =
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)
abstract class InputProcessor(store: StateStore) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be good to add some documentation for this abstract class.

private val getStateObj = if (hasInitialState) {
Some(ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs))
} else {
None
}

// Metrics
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
private val numOutputRows = longMetric("numOutputRows")
private val numRemovedStateRows = longMetric("numRemovedStateRows")
protected val numUpdatedStateRows: SQLMetric = longMetric("numUpdatedStateRows")
protected val numOutputRows: SQLMetric = longMetric("numOutputRows")
protected val numRemovedStateRows: SQLMetric = longMetric("numRemovedStateRows")

/**
* For every group, get the key, values and corresponding state and call the function,
Expand Down Expand Up @@ -362,7 +333,74 @@ case class FlatMapGroupsWithStateExec(
* @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
* @param hasTimedOut Whether this function is being called for a key timeout
*/
private def callFunctionAndUpdateState(
protected def callFunctionAndUpdateState(
stateData: StateData,
valueRowIter: Iterator[InternalRow],
hasTimedOut: Boolean): Iterator[InternalRow]
}
}

/**
* Physical operator for executing `FlatMapGroupsWithState`
*
* @param func function called on each group
* @param keyDeserializer used to extract the key object for each group.
* @param valueDeserializer used to extract the items in the iterator from an input row.
* @param initialStateDeserializer used to extract the state object from the initialState dataset
* @param groupingAttributes used to group the data
* @param dataAttributes used to read the data
* @param outputObjAttr Defines the output object
* @param stateEncoder used to serialize/deserialize state before calling `func`
* @param outputMode the output mode of `func`
* @param timeoutConf used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermark event time watermark for the current batch
* @param initialState the user specified initial state
* @param hasInitialState indicates whether the initial state is provided or not
* @param child the physical plan for the underlying data
*/
case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
initialStateDeserializer: Expression,
groupingAttributes: Seq[Attribute],
initialStateGroupAttrs: Seq[Attribute],
dataAttributes: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
stateEncoder: ExpressionEncoder[Any],
stateFormatVersion: Int,
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
initialState: SparkPlan,
hasInitialState: Boolean,
child: SparkPlan)
extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec {
import GroupStateImpl._
import FlatMapGroupsWithStateExecHelper._

override def left: SparkPlan = child

override def right: SparkPlan = initialState

override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
copy(child = newLeft, initialState = newRight)

override def createInputProcessor(
store: StateStore): InputProcessor = new InputProcessor(store) {
// Converters for translating input keys, values, output data between rows and Java objects
private val getKeyObj =
ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
private val getValueObj =
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)

override protected def callFunctionAndUpdateState(
stateData: StateData,
valueRowIter: Iterator[InternalRow],
hasTimedOut: Boolean): Iterator[InternalRow] = {
Expand Down Expand Up @@ -405,10 +443,6 @@ case class FlatMapGroupsWithStateExec(
CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
}
}

override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
copy(child = newLeft, initialState = newRight)
}

object FlatMapGroupsWithStateExec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1733,7 +1733,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
val store = newStateStore()
val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
mapGroupsFunc, timeoutConf, currentBatchTimestamp)
val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store)
val inputProcessor = mapGroupsSparkPlan.createInputProcessor(store)
val stateManager = mapGroupsSparkPlan.stateManager
val key = intToRow(0)
// Prepare store with prior state configs
Expand Down