diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index e4a515d203cc..f81b3d7241e7 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -472,6 +472,7 @@ def __hash__(self):
"pyspark.sql.tests.pandas.test_pandas_udf_typehints",
"pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations",
"pyspark.sql.tests.pandas.test_pandas_udf_window",
+ "pyspark.sql.tests.test_pandas_sqlmetrics",
"pyspark.sql.tests.test_readwriter",
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
diff --git a/docs/web-ui.md b/docs/web-ui.md
index d3356ec5a43f..e228d7fe2a98 100644
--- a/docs/web-ui.md
+++ b/docs/web-ui.md
@@ -406,6 +406,8 @@ Here is the list of SQL metrics:
time to build hash map | the time spent on building hash map | ShuffledHashJoin |
task commit time | the time spent on committing the output of a task after the writes succeed | any write operation on a file-based table |
job commit time | the time spent on committing the output of a job after the writes succeed | any write operation on a file-based table |
+ data sent to Python workers | the number of bytes of serialized data sent to the Python workers | ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowInPandas |
+ data returned from Python workers | the number of bytes of serialized data returned by the Python workers | ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowInPandas |
## Structured Streaming Tab
diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
new file mode 100644
index 000000000000..d182bafd8b54
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import cast
+
+from pyspark.sql.functions import pandas_udf
+from pyspark.testing.sqlutils import (
+ ReusedSQLTestCase,
+ have_pandas,
+ have_pyarrow,
+ pandas_requirement_message,
+ pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ cast(str, pandas_requirement_message or pyarrow_requirement_message),
+)
+class PandasSQLMetrics(ReusedSQLTestCase):
+ def test_pandas_sql_metrics_basic(self):
+ # SPARK-34265: Instrument Python UDFs using SQL metrics
+
+ python_sql_metrics = [
+ "data sent to Python workers",
+ "data returned from Python workers",
+ "number of output rows",
+ ]
+
+ @pandas_udf("long")
+ def test_pandas(col1):
+ return col1 * col1
+
+ self.spark.range(10).select(test_pandas("id")).collect()
+
+ statusStore = self.spark._jsparkSession.sharedState().statusStore()
+ lastExecId = statusStore.executionsList().last().executionId()
+ executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString()
+
+ for metric in python_sql_metrics:
+ self.assertIn(metric, executionMetrics)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 2f85149ee8e1..6a8b197742d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -46,7 +46,7 @@ case class AggregateInPandasExec(
udfExpressions: Seq[PythonUDF],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends UnaryExecNode {
+ extends UnaryExecNode with PythonSQLMetrics {
override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -163,7 +163,8 @@ case class AggregateInPandasExec(
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
- pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(projectedRowIter, context.partitionId(), context)
val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index bd8c72029dcb..f3531668c8e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
import org.apache.spark.sql.execution.streaming.GroupStateImpl
@@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner(
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
outputSchema: StructType,
- stateValueSchema: StructType)
+ stateValueSchema: StructType,
+ val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
with PythonArrowInput[InType]
with PythonArrowOutput[OutType] {
@@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner(
val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
while (inputIterator.hasNext) {
+ val startData = dataOut.size()
val (keyRow, groupState, dataIter) = inputIterator.next()
assert(dataIter.hasNext, "should have at least one data row!")
w.startNewGroup(keyRow, groupState)
@@ -126,6 +129,8 @@ class ApplyInPandasWithStatePythonRunner(
}
w.finalizeGroup()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
w.finalizeData()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 096712cf9352..b11dd4947af6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int)
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan,
evalType: Int)
- extends EvalPythonExec {
+ extends EvalPythonExec with PythonSQLMetrics {
private val batchSize = conf.arrowMaxRecordsPerBatch
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
@@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
argOffsets,
schema,
sessionLocalTimeZone,
- pythonRunnerConf).compute(batchIter, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(batchIter, context.partitionId(), context)
columnarBatchIter.flatMap { batch =>
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 8467feb91d14..dbafc444281e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -32,7 +33,8 @@ class ArrowPythonRunner(
argOffsets: Array[Array[Int]],
protected override val schema: StructType,
protected override val timeZoneId: String,
- protected override val workerConf: Map[String, String])
+ protected override val workerConf: Map[String, String],
+ val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
with BasicPythonArrowInput
with BasicPythonArrowOutput {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 10f7966b93d1..ca7ca2e2f80a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
* A physical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
- extends EvalPythonExec {
+ extends EvalPythonExec with PythonSQLMetrics {
protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
@@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
}.grouped(100).map(x => pickle.dumps(x.toArray))
// Output iterator for results from Python.
- val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+ val outputIterator =
+ new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics)
.compute(inputIterator, context.partitionId(), context)
val unpickle = new Unpickler
@@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
+ pythonMetrics("pythonNumRowsReceived") += 1
if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = fromJava(result)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 2661896ececc..1df9f37188a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
@@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner(
leftSchema: StructType,
rightSchema: StructType,
timeZoneId: String,
- conf: Map[String, String])
+ conf: Map[String, String],
+ val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[
(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
with BasicPythonArrowOutput {
@@ -77,10 +79,14 @@ class CoGroupedArrowPythonRunner(
// For each we first send the number of dataframes in each group then send
// first df, then send second df. End of data is marked by sending 0.
while (inputIterator.hasNext) {
+ val startData = dataOut.size()
dataOut.writeInt(2)
val (nextLeft, nextRight) = inputIterator.next()
writeGroup(nextLeft, leftSchema, dataOut, "left")
writeGroup(nextRight, rightSchema, dataOut, "right")
+
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
dataOut.writeInt(0)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index b39787b12a48..629df51e18ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec(
output: Seq[Attribute],
left: SparkPlan,
right: SparkPlan)
- extends SparkPlan with BinaryExecNode {
+ extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec(
}
override protected def doExecute(): RDD[InternalRow] = {
-
val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup)
@@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec(
StructType.fromAttributes(leftDedup),
StructType.fromAttributes(rightDedup),
sessionLocalTimeZone,
- pythonRunnerConf)
+ pythonRunnerConf,
+ pythonMetrics)
executePython(data, output, runner)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index f0e815e966e7..271ccdb6b271 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -50,7 +50,7 @@ case class FlatMapGroupsInPandasExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan)
- extends SparkPlan with UnaryExecNode {
+ extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -89,7 +89,8 @@ case class FlatMapGroupsInPandasExec(
Array(argOffsets),
StructType.fromAttributes(dedupAttributes),
sessionLocalTimeZone,
- pythonRunnerConf)
+ pythonRunnerConf,
+ pythonMetrics)
executePython(data, output, runner)
}}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 09123344c2e2..3b096f07241f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -62,7 +62,8 @@ case class FlatMapGroupsInPandasWithStateExec(
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
- child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+ child: SparkPlan)
+ extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase {
// TODO(SPARK-40444): Add the support of initial state.
override protected val initialStateDeserializer: Expression = null
@@ -166,7 +167,8 @@ case class FlatMapGroupsInPandasWithStateExec(
stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
groupingAttributes.toStructType,
outAttributes.toStructType,
- stateType)
+ stateType,
+ pythonMetrics)
val context = TaskContext.get()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index d25c13835407..450891c69483 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
* This is somewhat similar with [[FlatMapGroupsInPandasExec]] and
* `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow`
*/
-trait MapInBatchExec extends UnaryExecNode {
+trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
protected val func: Expression
protected val pythonEvalType: Int
@@ -75,7 +75,8 @@ trait MapInBatchExec extends UnaryExecNode {
argOffsets,
StructType(StructField("struct", outputTypes) :: Nil),
sessionLocalTimeZone,
- pythonRunnerConf).compute(batchIter, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(batchIter, context.partitionId(), context)
val unsafeProj = UnsafeProjection.create(output, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index bf66791183ec..5a0541d11cbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
@@ -41,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected val timeZoneId: String
+ protected def pythonMetrics: Map[String, SQLMetric]
+
protected def writeIteratorToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -115,6 +118,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
val arrowWriter = ArrowWriter.create(root)
while (inputIterator.hasNext) {
+ val startData = dataOut.size()
val nextBatch = inputIterator.next()
while (nextBatch.hasNext) {
@@ -124,6 +128,8 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
}
}
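
The hunk above meters the send side by diffing `DataOutputStream.size()` around each Arrow batch write. A minimal standalone sketch of that pattern (not part of the patch), with a plain `Long` standing in for the `pythonDataSent` SQLMetric:

```scala
// Sketch only: DataOutputStream.size() reports bytes written so far, so the
// delta around a batch write is that batch's serialized size.
import java.io.{ByteArrayOutputStream, DataOutputStream}

val dataOut = new DataOutputStream(new ByteArrayOutputStream())
var pythonDataSent = 0L  // stands in for pythonMetrics("pythonDataSent")
Seq(Array.fill[Byte](128)(1), Array.fill[Byte](256)(2)).foreach { batch =>
  val startData = dataOut.size()  // bytes written before this batch
  dataOut.write(batch)            // stands in for writer.writeBatch()
  pythonDataSent += dataOut.size() - startData
}
assert(pythonDataSent == 384L)
```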
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 339f114539c2..c12c690f776a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -37,6 +38,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
*/
private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] =>
+ protected def pythonMetrics: Map[String, SQLMetric]
+
protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
@@ -82,10 +85,15 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
}
try {
if (reader != null && batchLoaded) {
+ val bytesReadStart = reader.bytesRead()
batchLoaded = reader.loadNextBatch()
if (batchLoaded) {
val batch = new ColumnarBatch(vectors)
+ val rowCount = root.getRowCount
batch.setNumRows(root.getRowCount)
+ val bytesReadEnd = reader.bytesRead()
+ pythonMetrics("pythonNumRowsReceived") += rowCount
+ pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
deserializeColumnarBatch(batch, schema)
} else {
reader.close(false)
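
On the receive side, the metric is the delta of `reader.bytesRead()` across `loadNextBatch()`, and the row counter comes from `root.getRowCount`. A self-contained sketch of that bookkeeping against an in-memory Arrow IPC stream (not part of the patch; plain vars stand in for the SQLMetrics, and only standard `arrow-vector` APIs are assumed):

```scala
// Sketch only: round-trip one Arrow batch in memory and accumulate the two
// counters that back pythonDataReceived / pythonNumRowsReceived.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}

val allocator = new RootAllocator(Long.MaxValue)

// Write a 3-row batch to an in-memory Arrow IPC stream.
val vector = new BigIntVector("id", allocator)
vector.allocateNew(3)
(0 until 3).foreach(i => vector.set(i, i.toLong))
vector.setValueCount(3)
val root = VectorSchemaRoot.of(vector)
root.setRowCount(3)
val out = new ByteArrayOutputStream()
val writer = new ArrowStreamWriter(root, null, out)
writer.start()
writer.writeBatch()
writer.end()

// Read it back, metering bytes and rows per loaded batch.
val reader = new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray), allocator)
var pythonDataReceived = 0L      // stands in for pythonMetrics("pythonDataReceived")
var pythonNumRowsReceived = 0L   // stands in for pythonMetrics("pythonNumRowsReceived")
var batchLoaded = true
while (batchLoaded) {
  val bytesReadStart = reader.bytesRead()
  batchLoaded = reader.loadNextBatch()
  if (batchLoaded) {
    pythonNumRowsReceived += reader.getVectorSchemaRoot.getRowCount
    pythonDataReceived += reader.bytesRead() - bytesReadStart
  }
}
reader.close()
println(s"received $pythonDataReceived bytes, $pythonNumRowsReceived rows")
```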
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
new file mode 100644
index 000000000000..a748c1bc1008
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+private[sql] trait PythonSQLMetrics { self: SparkPlan =>
+
+ val pythonMetrics = Map(
+ "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext,
+ "data sent to Python workers"),
+ "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext,
+ "data returned from Python workers"),
+ "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext,
+ "number of output rows")
+ )
+
+ override lazy val metrics = pythonMetrics
+}
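
This trait is the single place the three metrics are declared; the two `createSizeMetric` entries are rendered by the Web UI as human-readable byte totals, while `createMetric` shows a plain count. A small sketch of exercising the same factory methods on a local session (not part of the patch; `SQLMetrics` lives in an internal package and is used here only for illustration):

```scala
// Sketch only: the two SQLMetrics factory methods used by PythonSQLMetrics.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.metric.SQLMetrics

val spark = SparkSession.builder().master("local[1]").appName("sqlmetrics-sketch").getOrCreate()
val sc = spark.sparkContext

val dataSent = SQLMetrics.createSizeMetric(sc, "data sent to Python workers")
val dataReceived = SQLMetrics.createSizeMetric(sc, "data returned from Python workers")
val rowsReceived = SQLMetrics.createMetric(sc, "number of output rows")

dataSent += 384L        // accumulated per batch by the Python runners
dataReceived += 128L
rowsReceived += 3L
println(s"${dataSent.name.get}: ${dataSent.value}, ${rowsReceived.name.get}: ${rowsReceived.value}")

spark.stop()
```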
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index d1109d251c28..09e06b55df3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark._
import org.apache.spark.api.python._
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
/**
@@ -31,7 +32,8 @@ import org.apache.spark.sql.internal.SQLConf
class PythonUDFRunner(
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
- argOffsets: Array[Array[Int]])
+ argOffsets: Array[Array[Int]],
+ pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[Array[Byte], Array[Byte]](
funcs, evalType, argOffsets) {
@@ -50,8 +52,13 @@ class PythonUDFRunner(
}
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+ val startData = dataOut.size()
+
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
}
}
@@ -77,6 +84,7 @@ class PythonUDFRunner(
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
+ pythonMetrics("pythonDataReceived") += length
obj
case 0 => Array.emptyByteArray
case SpecialLengths.TIMING_DATA =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index ccb1ed92525d..dcaffed89cca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -84,7 +84,7 @@ case class WindowInPandasExec(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
- extends WindowExecBase {
+ extends WindowExecBase with PythonSQLMetrics {
/**
* Helper functions and data structures for window bounds
@@ -375,7 +375,8 @@ case class WindowInPandasExec(
argOffsets,
pythonInputSchema,
sessionLocalTimeZone,
- pythonRunnerConf).compute(pythonInput, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(pythonInput, context.partitionId(), context)
val joined = new JoinedRow
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 2b8fc6515618..b540f9f00939 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.python.PythonSQLMetrics
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
@@ -93,7 +94,7 @@ trait StateStoreReader extends StatefulOperator {
}
/** An operator that writes to a StateStore. */
-trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
+trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan =>
override lazy val metrics = statefulOperatorCustomMetrics ++ Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
@@ -109,7 +110,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
"numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of shuffle partitions"),
"numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
"number of state store instances")
- ) ++ stateStoreCustomMetrics
+ ) ++ stateStoreCustomMetrics ++ pythonMetrics
/**
* Get the progress made by this stateful operator after execution. This should be called in
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
index 70784c20a8eb..7850b2d79b04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -84,4 +84,23 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
checkAnswer(actual, expected)
}
+
+ test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") {
+
+ val pythonSQLMetrics = List(
+ "data sent to Python workers",
+ "data returned from Python workers",
+ "number of output rows")
+
+ val df = base.groupBy(pythonTestUDF(base("a") + 1))
+ .agg(pythonTestUDF(pythonTestUDF(base("a") + 1)))
+ df.count()
+
+ val statusStore = spark.sharedState.statusStore
+ val lastExecId = statusStore.executionsList.last.executionId
+ val executionMetrics = statusStore.execution(lastExecId).get.metrics.mkString
+ for (metric <- pythonSQLMetrics) {
+ assert(executionMetrics.contains(metric))
+ }
+ }
}