-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-40411][SS] Refactor FlatMapGroupsWithStateExec to have a parent trait #37859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi | |
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.plans.physical.Distribution | ||
| import org.apache.spark.sql.execution._ | ||
| import org.apache.spark.sql.execution.metric.SQLMetric | ||
| import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ | ||
| import org.apache.spark.sql.execution.streaming.state._ | ||
| import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} | ||
|
|
@@ -33,59 +34,35 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} | |
|
|
||
| /** | ||
| * Physical operator for executing `FlatMapGroupsWithState` | ||
| * | ||
| * @param func function called on each group | ||
| * @param keyDeserializer used to extract the key object for each group. | ||
| * @param valueDeserializer used to extract the items in the iterator from an input row. | ||
| * @param initialStateDeserializer used to extract the state object from the initialState dataset | ||
| * @param groupingAttributes used to group the data | ||
| * @param dataAttributes used to read the data | ||
| * @param outputObjAttr Defines the output object | ||
| * @param stateEncoder used to serialize/deserialize state before calling `func` | ||
| * @param outputMode the output mode of `func` | ||
| * @param timeoutConf used to timeout groups that have not received data in a while | ||
| * @param batchTimestampMs processing timestamp of the current batch. | ||
| * @param eventTimeWatermark event time watermark for the current batch | ||
| * @param initialState the user specified initial state | ||
| * @param hasInitialState indicates whether the initial state is provided or not | ||
| * @param child the physical plan for the underlying data | ||
| */ | ||
| case class FlatMapGroupsWithStateExec( | ||
| func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], | ||
| keyDeserializer: Expression, | ||
| valueDeserializer: Expression, | ||
| initialStateDeserializer: Expression, | ||
| groupingAttributes: Seq[Attribute], | ||
| initialStateGroupAttrs: Seq[Attribute], | ||
| dataAttributes: Seq[Attribute], | ||
| initialStateDataAttrs: Seq[Attribute], | ||
| outputObjAttr: Attribute, | ||
| stateInfo: Option[StatefulOperatorStateInfo], | ||
| stateEncoder: ExpressionEncoder[Any], | ||
| stateFormatVersion: Int, | ||
| outputMode: OutputMode, | ||
| timeoutConf: GroupStateTimeout, | ||
| batchTimestampMs: Option[Long], | ||
| eventTimeWatermark: Option[Long], | ||
| initialState: SparkPlan, | ||
| hasInitialState: Boolean, | ||
| child: SparkPlan | ||
| ) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { | ||
|
|
||
| import FlatMapGroupsWithStateExecHelper._ | ||
| trait FlatMapGroupsWithStateExecBase | ||
| extends StateStoreWriter with WatermarkSupport { | ||
| import GroupStateImpl._ | ||
| import FlatMapGroupsWithStateExecHelper._ | ||
|
|
||
| override def left: SparkPlan = child | ||
| protected val groupingAttributes: Seq[Attribute] | ||
|
|
||
| override def right: SparkPlan = initialState | ||
| protected val initialStateDeserializer: Expression | ||
| protected val initialStateGroupAttrs: Seq[Attribute] | ||
| protected val initialStateDataAttrs: Seq[Attribute] | ||
| protected val initialState: SparkPlan | ||
| protected val hasInitialState: Boolean | ||
|
|
||
| val stateInfo: Option[StatefulOperatorStateInfo] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember I exposed them for a reason at that time.. let me change this to
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error is like this:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah OK that was from another trait. Thanks for explanation. |
||
| protected val stateEncoder: ExpressionEncoder[Any] | ||
| protected val stateFormatVersion: Int | ||
| protected val outputMode: OutputMode | ||
| protected val timeoutConf: GroupStateTimeout | ||
| protected val batchTimestampMs: Option[Long] | ||
| val eventTimeWatermark: Option[Long] | ||
|
|
||
| private val isTimeoutEnabled = timeoutConf != NoTimeout | ||
| private val watermarkPresent = child.output.exists { | ||
| protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout | ||
| protected val watermarkPresent: Boolean = child.output.exists { | ||
| case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true | ||
| case _ => false | ||
| } | ||
|
|
||
| private[sql] val stateManager = | ||
| lazy val stateManager: StateManager = | ||
| createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) | ||
|
|
||
| /** | ||
|
|
@@ -240,7 +217,7 @@ case class FlatMapGroupsWithStateExec( | |
| stateManager.stateSchema, | ||
| numColsPrefixKey = 0, | ||
| stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value) | ||
| val processor = new InputProcessor(store) | ||
| val processor = createInputProcessor(store) | ||
| processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator)) | ||
| } | ||
| } else { | ||
|
|
@@ -252,31 +229,25 @@ case class FlatMapGroupsWithStateExec( | |
| session.sqlContext.sessionState, | ||
| Some(session.sqlContext.streams.stateStoreCoordinator) | ||
| ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => | ||
| val processor = new InputProcessor(store) | ||
| val processor = createInputProcessor(store) | ||
| processDataWithPartition(singleIterator, store, processor) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** Helper class to update the state store */ | ||
| class InputProcessor(store: StateStore) { | ||
| def createInputProcessor(store: StateStore): InputProcessor | ||
|
|
||
| // Converters for translating input keys, values, output data between rows and Java objects | ||
| private val getKeyObj = | ||
| ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) | ||
| private val getValueObj = | ||
| ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) | ||
| private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) | ||
| abstract class InputProcessor(store: StateStore) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe good to have some doc for this abstract class? |
||
| private val getStateObj = if (hasInitialState) { | ||
| Some(ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs)) | ||
| } else { | ||
| None | ||
| } | ||
|
|
||
| // Metrics | ||
| private val numUpdatedStateRows = longMetric("numUpdatedStateRows") | ||
| private val numOutputRows = longMetric("numOutputRows") | ||
| private val numRemovedStateRows = longMetric("numRemovedStateRows") | ||
| protected val numUpdatedStateRows: SQLMetric = longMetric("numUpdatedStateRows") | ||
| protected val numOutputRows: SQLMetric = longMetric("numOutputRows") | ||
| protected val numRemovedStateRows: SQLMetric = longMetric("numRemovedStateRows") | ||
|
|
||
| /** | ||
| * For every group, get the key, values and corresponding state and call the function, | ||
|
|
@@ -362,7 +333,74 @@ case class FlatMapGroupsWithStateExec( | |
| * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty | ||
| * @param hasTimedOut Whether this function is being called for a key timeout | ||
| */ | ||
| private def callFunctionAndUpdateState( | ||
| protected def callFunctionAndUpdateState( | ||
| stateData: StateData, | ||
| valueRowIter: Iterator[InternalRow], | ||
| hasTimedOut: Boolean): Iterator[InternalRow] | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Physical operator for executing `FlatMapGroupsWithState` | ||
| * | ||
| * @param func function called on each group | ||
| * @param keyDeserializer used to extract the key object for each group. | ||
| * @param valueDeserializer used to extract the items in the iterator from an input row. | ||
| * @param initialStateDeserializer used to extract the state object from the initialState dataset | ||
| * @param groupingAttributes used to group the data | ||
| * @param dataAttributes used to read the data | ||
| * @param outputObjAttr Defines the output object | ||
| * @param stateEncoder used to serialize/deserialize state before calling `func` | ||
| * @param outputMode the output mode of `func` | ||
| * @param timeoutConf used to timeout groups that have not received data in a while | ||
| * @param batchTimestampMs processing timestamp of the current batch. | ||
| * @param eventTimeWatermark event time watermark for the current batch | ||
| * @param initialState the user specified initial state | ||
| * @param hasInitialState indicates whether the initial state is provided or not | ||
| * @param child the physical plan for the underlying data | ||
| */ | ||
| case class FlatMapGroupsWithStateExec( | ||
| func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], | ||
| keyDeserializer: Expression, | ||
| valueDeserializer: Expression, | ||
| initialStateDeserializer: Expression, | ||
| groupingAttributes: Seq[Attribute], | ||
| initialStateGroupAttrs: Seq[Attribute], | ||
| dataAttributes: Seq[Attribute], | ||
| initialStateDataAttrs: Seq[Attribute], | ||
| outputObjAttr: Attribute, | ||
| stateInfo: Option[StatefulOperatorStateInfo], | ||
| stateEncoder: ExpressionEncoder[Any], | ||
| stateFormatVersion: Int, | ||
| outputMode: OutputMode, | ||
| timeoutConf: GroupStateTimeout, | ||
| batchTimestampMs: Option[Long], | ||
| eventTimeWatermark: Option[Long], | ||
| initialState: SparkPlan, | ||
| hasInitialState: Boolean, | ||
| child: SparkPlan) | ||
| extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec { | ||
| import GroupStateImpl._ | ||
| import FlatMapGroupsWithStateExecHelper._ | ||
|
|
||
| override def left: SparkPlan = child | ||
|
|
||
| override def right: SparkPlan = initialState | ||
|
|
||
| override protected def withNewChildrenInternal( | ||
| newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec = | ||
| copy(child = newLeft, initialState = newRight) | ||
|
|
||
| override def createInputProcessor( | ||
| store: StateStore): InputProcessor = new InputProcessor(store) { | ||
| // Converters for translating input keys, values, output data between rows and Java objects | ||
| private val getKeyObj = | ||
| ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) | ||
| private val getValueObj = | ||
| ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) | ||
| private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) | ||
|
|
||
| override protected def callFunctionAndUpdateState( | ||
| stateData: StateData, | ||
| valueRowIter: Iterator[InternalRow], | ||
| hasTimedOut: Boolean): Iterator[InternalRow] = { | ||
|
|
@@ -405,10 +443,6 @@ case class FlatMapGroupsWithStateExec( | |
| CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) | ||
| } | ||
| } | ||
|
|
||
| override protected def withNewChildrenInternal( | ||
| newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec = | ||
| copy(child = newLeft, initialState = newRight) | ||
| } | ||
|
|
||
| object FlatMapGroupsWithStateExec { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's nice if we could update this doc too.