diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 3ff539b9ef32b..790a652f21124 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -33,59 +34,35 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
 
 /**
  * Physical operator for executing `FlatMapGroupsWithState`
- *
- * @param func function called on each group
- * @param keyDeserializer used to extract the key object for each group.
- * @param valueDeserializer used to extract the items in the iterator from an input row.
- * @param initialStateDeserializer used to extract the state object from the initialState dataset
- * @param groupingAttributes used to group the data
- * @param dataAttributes used to read the data
- * @param outputObjAttr Defines the output object
- * @param stateEncoder used to serialize/deserialize state before calling `func`
- * @param outputMode the output mode of `func`
- * @param timeoutConf used to timeout groups that have not received data in a while
- * @param batchTimestampMs processing timestamp of the current batch.
- * @param eventTimeWatermark event time watermark for the current batch
- * @param initialState the user specified initial state
- * @param hasInitialState indicates whether the initial state is provided or not
- * @param child the physical plan for the underlying data
  */
-case class FlatMapGroupsWithStateExec(
-    func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
-    keyDeserializer: Expression,
-    valueDeserializer: Expression,
-    initialStateDeserializer: Expression,
-    groupingAttributes: Seq[Attribute],
-    initialStateGroupAttrs: Seq[Attribute],
-    dataAttributes: Seq[Attribute],
-    initialStateDataAttrs: Seq[Attribute],
-    outputObjAttr: Attribute,
-    stateInfo: Option[StatefulOperatorStateInfo],
-    stateEncoder: ExpressionEncoder[Any],
-    stateFormatVersion: Int,
-    outputMode: OutputMode,
-    timeoutConf: GroupStateTimeout,
-    batchTimestampMs: Option[Long],
-    eventTimeWatermark: Option[Long],
-    initialState: SparkPlan,
-    hasInitialState: Boolean,
-    child: SparkPlan
-  ) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
-
-  import FlatMapGroupsWithStateExecHelper._
+trait FlatMapGroupsWithStateExecBase
+  extends StateStoreWriter with WatermarkSupport {
   import GroupStateImpl._
+  import FlatMapGroupsWithStateExecHelper._
 
-  override def left: SparkPlan = child
+  protected val groupingAttributes: Seq[Attribute]
 
-  override def right: SparkPlan = initialState
+  protected val initialStateDeserializer: Expression
+  protected val initialStateGroupAttrs: Seq[Attribute]
+  protected val initialStateDataAttrs: Seq[Attribute]
+  protected val initialState: SparkPlan
+  protected val hasInitialState: Boolean
+
+  val stateInfo: Option[StatefulOperatorStateInfo]
+  protected val stateEncoder: ExpressionEncoder[Any]
+  protected val stateFormatVersion: Int
+  protected val outputMode: OutputMode
+  protected val timeoutConf: GroupStateTimeout
+  protected val batchTimestampMs: Option[Long]
+  val eventTimeWatermark: Option[Long]
 
-  private val isTimeoutEnabled = timeoutConf != NoTimeout
-  private val watermarkPresent = child.output.exists {
+  protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout
+  protected val watermarkPresent: Boolean = child.output.exists {
     case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
     case _ => false
   }
 
-  private[sql] val stateManager =
+  lazy val stateManager: StateManager =
     createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
 
   /**
@@ -240,7 +217,7 @@ case class FlatMapGroupsWithStateExec(
             stateManager.stateSchema,
             numColsPrefixKey = 0,
             stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
-          val processor = new InputProcessor(store)
+          val processor = createInputProcessor(store)
           processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
       }
     } else {
@@ -252,21 +229,15 @@ case class FlatMapGroupsWithStateExec(
         session.sqlContext.sessionState,
         Some(session.sqlContext.streams.stateStoreCoordinator)
       ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
-        val processor = new InputProcessor(store)
+        val processor = createInputProcessor(store)
         processDataWithPartition(singleIterator, store, processor)
       }
     }
   }
 
-  /** Helper class to update the state store */
-  class InputProcessor(store: StateStore) {
+  def createInputProcessor(store: StateStore): InputProcessor
 
-    // Converters for translating input keys, values, output data between rows and Java objects
-    private val getKeyObj =
-      ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
-    private val getValueObj =
-      ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
-    private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)
+  abstract class InputProcessor(store: StateStore) {
     private val getStateObj = if (hasInitialState) {
       Some(ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs))
     } else {
@@ -274,9 +245,9 @@ case class FlatMapGroupsWithStateExec(
     }
 
     // Metrics
-    private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
-    private val numOutputRows = longMetric("numOutputRows")
-    private val numRemovedStateRows = longMetric("numRemovedStateRows")
+    protected val numUpdatedStateRows: SQLMetric = longMetric("numUpdatedStateRows")
+    protected val numOutputRows: SQLMetric = longMetric("numOutputRows")
+    protected val numRemovedStateRows: SQLMetric = longMetric("numRemovedStateRows")
 
     /**
      * For every group, get the key, values and corresponding state and call the function,
@@ -362,7 +333,74 @@ case class FlatMapGroupsWithStateExec(
      * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
      * @param hasTimedOut Whether this function is being called for a key timeout
      */
-    private def callFunctionAndUpdateState(
+    protected def callFunctionAndUpdateState(
+        stateData: StateData,
+        valueRowIter: Iterator[InternalRow],
+        hasTimedOut: Boolean): Iterator[InternalRow]
+  }
+}
+
+/**
+ * Physical operator for executing `FlatMapGroupsWithState`
+ *
+ * @param func function called on each group
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
+ * @param initialStateDeserializer used to extract the state object from the initialState dataset
+ * @param groupingAttributes used to group the data
+ * @param dataAttributes used to read the data
+ * @param outputObjAttr Defines the output object
+ * @param stateEncoder used to serialize/deserialize state before calling `func`
+ * @param outputMode the output mode of `func`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param initialState the user specified initial state
+ * @param hasInitialState indicates whether the initial state is provided or not
+ * @param child the physical plan for the underlying data
+ */
+case class FlatMapGroupsWithStateExec(
+    func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    initialStateDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    initialStateGroupAttrs: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    initialStateDataAttrs: Seq[Attribute],
+    outputObjAttr: Attribute,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateEncoder: ExpressionEncoder[Any],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    initialState: SparkPlan,
+    hasInitialState: Boolean,
+    child: SparkPlan)
+  extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec {
+  import GroupStateImpl._
+  import FlatMapGroupsWithStateExecHelper._
+
+  override def left: SparkPlan = child
+
+  override def right: SparkPlan = initialState
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
+    copy(child = newLeft, initialState = newRight)
+
+  override def createInputProcessor(
+      store: StateStore): InputProcessor = new InputProcessor(store) {
+    // Converters for translating input keys, values, output data between rows and Java objects
+    private val getKeyObj =
+      ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+    private val getValueObj =
+      ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+    private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)
+
+    override protected def callFunctionAndUpdateState(
       stateData: StateData,
       valueRowIter: Iterator[InternalRow],
       hasTimedOut: Boolean): Iterator[InternalRow] = {
@@ -405,10 +443,6 @@ case class FlatMapGroupsWithStateExec(
       CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
     }
   }
-
-  override protected def withNewChildrenInternal(
-      newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
-    copy(child = newLeft, initialState = newRight)
 }
 
 object FlatMapGroupsWithStateExec {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 9d34ceea8dd47..b7c9aa4178090 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -1733,7 +1733,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
     val store = newStateStore()
     val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
       mapGroupsFunc, timeoutConf, currentBatchTimestamp)
-    val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store)
+    val inputProcessor = mapGroupsSparkPlan.createInputProcessor(store)
    val stateManager = mapGroupsSparkPlan.stateManager
     val key = intToRow(0)
     // Prepare store with prior state configs
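
Note on the extension point (illustration only, not part of the patch): the refactor moves the shared state-store driver loop into FlatMapGroupsWithStateExecBase and leaves row/object conversion to whatever InputProcessor a concrete operator returns from createInputProcessor. The self-contained Scala sketch below models that shape; Store, Row, StateData, ExecBase and DemoExec are simplified stand-ins invented for this sketch, not Spark's classes, and only the structure mirrors the trait introduced above.

// Minimal, self-contained model of the template-method split in this patch.
// All types here are stand-ins; only the shape mirrors
// FlatMapGroupsWithStateExecBase / InputProcessor / createInputProcessor.
object InputProcessorPatternDemo {
  final case class Store(name: String)  // stand-in for StateStore
  final case class Row(value: String)   // stand-in for InternalRow
  final case class StateData(key: String)

  trait ExecBase {
    // Shared per-partition driver loop lives in the base trait, as in the patch.
    def processDataWithPartition(input: Iterator[Row], store: Store): Iterator[Row] = {
      val processor = createInputProcessor(store)  // was: new InputProcessor(store)
      processor.processNewData(input)
    }

    // Factory hook: each concrete operator builds its own processor.
    def createInputProcessor(store: Store): InputProcessor

    abstract class InputProcessor(store: Store) {
      // Public entry point shared by all operators.
      def processNewData(rows: Iterator[Row]): Iterator[Row] =
        callFunctionAndUpdateState(StateData("group-0"), rows, hasTimedOut = false)

      // Operator-specific hook, overridden per row/object format.
      protected def callFunctionAndUpdateState(
          stateData: StateData,
          valueRowIter: Iterator[Row],
          hasTimedOut: Boolean): Iterator[Row]
    }
  }

  // Plays the role of FlatMapGroupsWithStateExec: supplies the converters and
  // the user function through an anonymous InputProcessor subclass.
  class DemoExec(func: String => String) extends ExecBase {
    override def createInputProcessor(store: Store): InputProcessor =
      new InputProcessor(store) {
        override protected def callFunctionAndUpdateState(
            stateData: StateData,
            valueRowIter: Iterator[Row],
            hasTimedOut: Boolean): Iterator[Row] =
          valueRowIter.map(r => Row(func(r.value)))
      }
  }

  def main(args: Array[String]): Unit = {
    val out = new DemoExec(_.toUpperCase)
      .processDataWithPartition(Iterator(Row("a"), Row("b")), Store("demo"))
    out.foreach(r => println(r.value))  // prints: A, B
  }
}

In the patch, FlatMapGroupsWithStateExec plays the DemoExec role: its anonymous InputProcessor captures keyDeserializer, valueDeserializer and the output converter, which is what leaves the base trait free of Scala-specific row handling and reusable by operators with a different input format.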