diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index e4a515d203cc..f81b3d7241e7 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -472,6 +472,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_typehints", "pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations", "pyspark.sql.tests.pandas.test_pandas_udf_window", + "pyspark.sql.tests.test_pandas_sqlmetrics", "pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", diff --git a/docs/web-ui.md b/docs/web-ui.md index d3356ec5a43f..e228d7fe2a98 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -406,6 +406,8 @@ Here is the list of SQL metrics: time to build hash map the time spent on building hash map ShuffledHashJoin task commit time the time spent on committing the output of a task after the writes succeed any write operation on a file-based table job commit time the time spent on committing the output of a job after the writes succeed any write operation on a file-based table + data sent to Python workers the number of bytes of serialized data sent to the Python workers ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowInPandas + data returned from Python workers the number of bytes of serialized data received back from the Python workers ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowInPandas ## Structured Streaming Tab diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py new file mode 100644 index 000000000000..d182bafd8b54 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +import unittest +from typing import cast + +from pyspark.sql.functions import pandas_udf +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class PandasSQLMetrics(ReusedSQLTestCase): + def test_pandas_sql_metrics_basic(self): + # SPARK-34265: Instrument Python UDFs using SQL metrics + + python_sql_metrics = [ + "data sent to Python workers", + "data returned from Python workers", + "number of output rows", + ] + + @pandas_udf("long") + def test_pandas(col1): + return col1 * col1 + + self.spark.range(10).select(test_pandas("id")).collect() + + statusStore = self.spark._jsparkSession.sharedState().statusStore() + lastExecId = statusStore.executionsList().last().executionId() + executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString() + + for metric in python_sql_metrics: + self.assertIn(metric, executionMetrics) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 2f85149ee8e1..6a8b197742d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -46,7 +46,7 @@ case class AggregateInPandasExec( udfExpressions: Seq[PythonUDF], resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with PythonSQLMetrics { override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -163,7 +163,8 @@ case class AggregateInPandasExec( argOffsets, aggInputSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index bd8c72029dcb..f3531668c8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA 
import org.apache.spark.sql.execution.streaming.GroupStateImpl @@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner( stateEncoder: ExpressionEncoder[Row], keySchema: StructType, outputSchema: StructType, - stateValueSchema: StructType) + stateValueSchema: StructType, + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) with PythonArrowInput[InType] with PythonArrowOutput[OutType] { @@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner( val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) while (inputIterator.hasNext) { + val startData = dataOut.size() val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") w.startNewGroup(keyRow, groupState) @@ -126,6 +129,8 @@ class ApplyInPandasWithStatePythonRunner( } w.finalizeGroup() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } w.finalizeData() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 096712cf9352..b11dd4947af6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int) */ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan, evalType: Int) - extends EvalPythonExec { + extends EvalPythonExec with PythonSQLMetrics { private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone @@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] argOffsets, schema, sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(batchIter, context.partitionId(), context) columnarBatchIter.flatMap { batch => val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 8467feb91d14..dbafc444281e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -32,7 +33,8 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], protected override val schema: StructType, protected override val timeZoneId: String, - protected override val workerConf: Map[String, String]) + protected override val workerConf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) with BasicPythonArrowInput with BasicPythonArrowOutput { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 10f7966b93d1..ca7ca2e2f80a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * A physical plan that evaluates a [[PythonUDF]] */ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan) - extends EvalPythonExec { + extends EvalPythonExec with PythonSQLMetrics { protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. - val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets) + val outputIterator = + new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler @@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => + pythonMetrics("pythonNumRowsReceived") += 1 if (udfs.length == 1) { // fast path for single UDF mutableRow(0) = fromJava(result) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index 2661896ececc..1df9f37188a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner( leftSchema: StructType, rightSchema: StructType, timeZoneId: String, - conf: Map[String, String]) + conf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[ (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets) with BasicPythonArrowOutput { @@ -77,10 +79,14 @@ class CoGroupedArrowPythonRunner( // For each we first send the number of dataframes in each group then send // first df, then send second df. End of data is marked by sending 0. 
while (inputIterator.hasNext) { + val startData = dataOut.size() dataOut.writeInt(2) val (nextLeft, nextRight) = inputIterator.next() writeGroup(nextLeft, leftSchema, dataOut, "left") writeGroup(nextRight, rightSchema, dataOut, "right") + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } dataOut.writeInt(0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index b39787b12a48..629df51e18ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec( output: Seq[Attribute], left: SparkPlan, right: SparkPlan) - extends SparkPlan with BinaryExecNode { + extends SparkPlan with BinaryExecNode with PythonSQLMetrics { private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec( } override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) @@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec( StructType.fromAttributes(leftDedup), StructType.fromAttributes(rightDedup), sessionLocalTimeZone, - pythonRunnerConf) + pythonRunnerConf, + pythonMetrics) executePython(data, output, runner) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index f0e815e966e7..271ccdb6b271 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsInPandasExec( func: Expression, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan with UnaryExecNode { + extends SparkPlan with UnaryExecNode with PythonSQLMetrics { private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -89,7 +89,8 @@ case class FlatMapGroupsInPandasExec( Array(argOffsets), StructType.fromAttributes(dedupAttributes), sessionLocalTimeZone, - pythonRunnerConf) + pythonRunnerConf, + pythonMetrics) executePython(data, output, runner) }} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 09123344c2e2..3b096f07241f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -62,7 +62,8 @@ case class FlatMapGroupsInPandasWithStateExec( timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], eventTimeWatermark: Option[Long], - child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + child: SparkPlan) + extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase { // 
TODO(SPARK-40444): Add the support of initial state. override protected val initialStateDeserializer: Expression = null @@ -166,7 +167,8 @@ case class FlatMapGroupsInPandasWithStateExec( stateEncoder.asInstanceOf[ExpressionEncoder[Row]], groupingAttributes.toStructType, outAttributes.toStructType, - stateType) + stateType, + pythonMetrics) val context = TaskContext.get() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index d25c13835407..450891c69483 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow` */ -trait MapInBatchExec extends UnaryExecNode { +trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { protected val func: Expression protected val pythonEvalType: Int @@ -75,7 +75,8 @@ trait MapInBatchExec extends UnaryExecNode { argOffsets, StructType(StructField("struct", outputTypes) :: Nil), sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(batchIter, context.partitionId(), context) val unsafeProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index bf66791183ec..5a0541d11cbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -41,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected val timeZoneId: String + protected def pythonMetrics: Map[String, SQLMetric] + protected def writeIteratorToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -115,6 +118,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In val arrowWriter = ArrowWriter.create(root) while (inputIterator.hasNext) { + val startData = dataOut.size() val nextBatch = inputIterator.next() while (nextBatch.hasNext) { @@ -124,6 +128,8 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In arrowWriter.finish() writer.writeBatch() arrowWriter.reset() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 339f114539c2..c12c690f776a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} @@ -37,6 +38,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column */ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => + protected def pythonMetrics: Map[String, SQLMetric] + protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT @@ -82,10 +85,15 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ } try { if (reader != null && batchLoaded) { + val bytesReadStart = reader.bytesRead() batchLoaded = reader.loadNextBatch() if (batchLoaded) { val batch = new ColumnarBatch(vectors) + val rowCount = root.getRowCount batch.setNumRows(root.getRowCount) + val bytesReadEnd = reader.bytesRead() + pythonMetrics("pythonNumRowsReceived") += rowCount + pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart deserializeColumnarBatch(batch, schema) } else { reader.close(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala new file mode 100644 index 000000000000..a748c1bc1008 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetrics + +private[sql] trait PythonSQLMetrics { self: SparkPlan => + + val pythonMetrics = Map( + "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext, + "data sent to Python workers"), + "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext, + "data returned from Python workers"), + "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext, + "number of output rows") + ) + + override lazy val metrics = pythonMetrics +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index d1109d251c28..09e06b55df3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark._ import org.apache.spark.api.python._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf /** @@ -31,7 +32,8 @@ import org.apache.spark.sql.internal.SQLConf class PythonUDFRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, - argOffsets: Array[Array[Int]]) + argOffsets: Array[Array[Int]], + pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[Array[Byte], Array[Byte]]( funcs, evalType, argOffsets) { @@ -50,8 +52,13 @@ class PythonUDFRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val startData = dataOut.size() + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } } } @@ -77,6 +84,7 @@ class PythonUDFRunner( case length if length > 0 => val obj = new Array[Byte](length) stream.readFully(obj) + pythonMetrics("pythonDataReceived") += length obj case 0 => Array.emptyByteArray case SpecialLengths.TIMING_DATA => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index ccb1ed92525d..dcaffed89cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -84,7 +84,7 @@ case class WindowInPandasExec( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) - extends WindowExecBase { + extends WindowExecBase with PythonSQLMetrics { /** * Helper functions and data structures for window bounds @@ -375,7 +375,8 @@ case class WindowInPandasExec( argOffsets, pythonInputSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(pythonInput, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 2b8fc6515618..b540f9f00939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.python.PythonSQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ @@ -93,7 +94,7 @@ trait StateStoreReader extends StatefulOperator { } /** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator { self: SparkPlan => +trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan => override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -109,7 +110,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => "numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of shuffle partitions"), "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext, "number of state store instances") - ) ++ stateStoreCustomMetrics + ) ++ stateStoreCustomMetrics ++ pythonMetrics /** * Get the progress made by this stateful operator after execution. This should be called in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 70784c20a8eb..7850b2d79b04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -84,4 +84,23 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { checkAnswer(actual, expected) } + + test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") { + + val pythonSQLMetrics = List( + "data sent to Python workers", + "data returned from Python workers", + "number of output rows") + + val df = base.groupBy(pythonTestUDF(base("a") + 1)) + .agg(pythonTestUDF(pythonTestUDF(base("a") + 1))) + df.count() + + val statusStore = spark.sharedState.statusStore + val lastExecId = statusStore.executionsList.last.executionId + val executionMetrics = statusStore.execution(lastExecId).get.metrics.mkString + for (metric <- pythonSQLMetrics) { + assert(executionMetrics.contains(metric)) + } + } }
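
For reviewers who want to eyeball the new metrics outside of the two test suites in this diff, the sketch below reuses the same status-store accessors as the new python/pyspark/sql/tests/test_pandas_sqlmetrics.py. It is only an illustration of how the metrics surface in a pyspark shell (the square UDF and the session setup are hypothetical, not part of the patch), and it assumes a build containing this change with pandas and pyarrow available on the driver.

# Illustrative check of the new Python SQL metrics in a pyspark shell.
# A sketch, not part of the patch; assumes pandas and pyarrow are installed.
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()

@pandas_udf("long")
def square(col1):
    # A trivial pandas UDF so the physical plan contains ArrowEvalPython.
    return col1 * col1

spark.range(10).select(square("id")).collect()

# Same JVM status-store path used by the new Python test: take the most
# recent SQL execution and list the metrics registered for it.
status_store = spark._jsparkSession.sharedState().statusStore()
last_exec_id = status_store.executionsList().last().executionId()
print(status_store.execution(last_exec_id).get().metrics().mkString("\n"))
# The output should include "data sent to Python workers",
# "data returned from Python workers" and "number of output rows",
# which are the names both new tests assert on.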