1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -472,6 +472,7 @@ def __hash__(self):
"pyspark.sql.tests.pandas.test_pandas_udf_typehints",
"pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations",
"pyspark.sql.tests.pandas.test_pandas_udf_window",
"pyspark.sql.tests.test_pandas_sqlmetrics",
"pyspark.sql.tests.test_readwriter",
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
2 changes: 2 additions & 0 deletions docs/web-ui.md
@@ -406,6 +406,8 @@ Here is the list of SQL metrics:
<tr><td> <code>time to build hash map</code> </td><td> the time spent on building hash map </td><td> ShuffledHashJoin </td></tr>
<tr><td> <code>task commit time</code> </td><td> the time spent on committing the output of a task after the writes succeed </td><td> any write operation on a file-based table </td></tr>
<tr><td> <code>job commit time</code> </td><td> the time spent on committing the output of a job after the writes succeed </td><td> any write operation on a file-based table </td></tr>
<tr><td> <code>data sent to Python workers</code> </td><td> the number of bytes of serialized data sent to the Python workers </td><td> ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowInPandas </td></tr>
<tr><td> <code>data returned from Python workers</code> </td><td> the number of bytes of serialized data received back from the Python workers </td><td> ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowInPandas </td></tr>
</table>

## Structured Streaming Tab
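The two rows added to the table above surface new SQL metrics for every Python-evaluating operator. In the Scala diffs below they are keyed as pythonDataSent and pythonDataReceived (plus the pre-existing pythonNumRowsReceived key for "number of output rows"), and each operator picks them up by mixing in a PythonSQLMetrics trait. The trait's body is not part of this excerpt, so the following is only a minimal sketch of how such metrics are registered; the object name and method here are invented for illustration, while the metric keys and display names come from this PR.

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

// Sketch only: approximates the metric registration that the PR's PythonSQLMetrics
// trait performs on each Python-evaluating SparkPlan node.
object PythonSQLMetricsSketch {
  def create(sc: SparkContext): Map[String, SQLMetric] = Map(
    // Byte counters shown as size metrics in the SQL tab.
    "pythonDataSent" ->
      SQLMetrics.createSizeMetric(sc, "data sent to Python workers"),
    "pythonDataReceived" ->
      SQLMetrics.createSizeMetric(sc, "data returned from Python workers"),
    // Row counter checked by the new Python test below.
    "pythonNumRowsReceived" ->
      SQLMetrics.createMetric(sc, "number of output rows")
  )
}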
68 changes: 68 additions & 0 deletions python/pyspark/sql/tests/test_pandas_sqlmetrics.py
@@ -0,0 +1,68 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
from typing import cast

from pyspark.sql.functions import pandas_udf
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)


@unittest.skipIf(
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class PandasSQLMetrics(ReusedSQLTestCase):
def test_pandas_sql_metrics_basic(self):
# SPARK-34265: Instrument Python UDFs using SQL metrics

python_sql_metrics = [
"data sent to Python workers",
"data returned from Python workers",
"number of output rows",
]

@pandas_udf("long")
def test_pandas(col1):
return col1 * col1

self.spark.range(10).select(test_pandas("id")).collect()

statusStore = self.spark._jsparkSession.sharedState().statusStore()
lastExecId = statusStore.executionsList().last().executionId()
executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString()

for metric in python_sql_metrics:
self.assertIn(metric, executionMetrics)


if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
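For driver-side debugging from Scala, roughly the same check as the test above can be written against the SQL status store. This is a sketch that mirrors the call chain the Python test reaches via Py4J; it assumes a query containing a Python UDF has already run on spark, and that the code lives somewhere with access to Spark's private[sql] internals (which Py4J bypasses).

import org.apache.spark.sql.SparkSession

// Sketch: Scala counterpart of the status-store check in the Python test above.
object PythonSQLMetricsCheck {
  def assertPythonMetricsReported(spark: SparkSession): Unit = {
    val statusStore = spark.sharedState.statusStore  // private[sql] internals
    val lastExecId = statusStore.executionsList().last.executionId
    val executionMetrics = statusStore.execution(lastExecId).get.metrics.mkString

    Seq(
      "data sent to Python workers",
      "data returned from Python workers",
      "number of output rows"
    ).foreach { name =>
      assert(executionMetrics.contains(name), s"missing SQL metric: $name")
    }
  }
}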
AggregateInPandasExec.scala
@@ -46,7 +46,7 @@ case class AggregateInPandasExec(
udfExpressions: Seq[PythonUDF],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode {
extends UnaryExecNode with PythonSQLMetrics {

override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

@@ -163,7 +163,8 @@ case class AggregateInPandasExec(
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)
pythonRunnerConf,
pythonMetrics).compute(projectedRowIter, context.partitionId(), context)

val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
ApplyInPandasWithStatePythonRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
import org.apache.spark.sql.execution.streaming.GroupStateImpl
@@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner(
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
outputSchema: StructType,
stateValueSchema: StructType)
stateValueSchema: StructType,
val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
with PythonArrowInput[InType]
with PythonArrowOutput[OutType] {
@@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner(
val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)

while (inputIterator.hasNext) {
val startData = dataOut.size()
val (keyRow, groupState, dataIter) = inputIterator.next()
assert(dataIter.hasNext, "should have at least one data row!")
w.startNewGroup(keyRow, groupState)
@@ -126,6 +129,8 @@ }
}

w.finalizeGroup()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
}

w.finalizeData()
ArrowEvalPythonExec.scala
@@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int)
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan,
Member:

The error is from here:

java.lang.reflect.InvocationTargetException
	at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
	at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
	at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
	at java.lang.reflect.Constructor.newInstance(Constructor.java:423)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$6(TreeNode.scala:738)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:83)
	at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:737)
	at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:694)
	at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:99)
	at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:59)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:223)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:508)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$doCanonicalize$1(QueryPlan.scala:506)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
	at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:506)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.doCanonicalize(AdaptiveSparkPlanExec.scala:210)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.doCanonicalize(AdaptiveSparkPlanExec.scala:64)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:850)
	at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:814)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized$lzycompute(subquery.scala:72)
	at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized(subquery.scala:71)
	at org.apache.spark.sql.catalyst.expressions.CastBase.preCanonicalized$lzycompute(Cast.scala:319)
	at org.apache.spark.sql.catalyst.expressions.CastBase.preCanonicalized(Cast.scala:318)
	at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized$lzycompute(Expression.scala:261)
	at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized(Expression.scala:261)
	at org.apache.spark.sql.catalyst.expressions.Expression.semanticHash(Expression.scala:278)
	at org.apache.spark.sql.catalyst.expressions.ExpressionEquals.hashCode(EquivalentExpressions.scala:260)
	at scala.runtime.Statics.anyHash(Statics.java:122)
	at scala.collection.mutable.HashTable$HashUtils.elemHashCode(HashTable.scala:416)
	at scala.collection.mutable.HashTable$HashUtils.elemHashCode$(HashTable.scala:416)
	at scala.collection.mutable.HashMap.elemHashCode(HashMap.scala:44)
	at scala.collection.mutable.HashTable.findEntry(HashTable.scala:136)
	at scala.collection.mutable.HashTable.findEntry$(HashTable.scala:135)
	at scala.collection.mutable.HashMap.findEntry(HashMap.scala:44)
	at scala.collection.mutable.HashMap.get(HashMap.scala:74)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.updateExprInMap(EquivalentExpressions.scala:59)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.updateExprTree(EquivalentExpressions.scala:202)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:186)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1(CodeGenerator.scala:1218)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1$adapted(CodeGenerator.scala:1218)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.subexpressionElimination(CodeGenerator.scala:1218)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.generateExpressions(CodeGenerator.scala:1271)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.create(GenerateMutableProjection.scala:64)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.generate(GenerateMutableProjection.scala:49)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:84)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:80)
	at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:47)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:95)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:103)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:118)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:507)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1475)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:510)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.NullPointerException
	at org.apache.spark.sql.execution.SparkPlan.sparkContext(SparkPlan.scala:62)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.org$apache$spark$sql$execution$python$PythonSQLMetrics$$super$sparkContext(ArrowEvalPythonExec.scala:62)
	at org.apache.spark.sql.execution.python.PythonSQLMetrics.$init$(PythonSQLMetrics.scala:27)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.<init>(ArrowEvalPythonExec.scala:64)
	... 89 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2618)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2567)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2566)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2566)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2806)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2748)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2737)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2249)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2268)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2293)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1021)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1020)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:408)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$executeCollect$1(AdaptiveSparkPlanExec.scala:342)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.withFinalPlanUpdate(AdaptiveSparkPlanExec.scala:370)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.executeCollect(AdaptiveSparkPlanExec.scala:342)
	at org.apache.spark.sql.execution.SparkPlan.executeCollectPublic(SparkPlan.scala:435)
	at org.apache.spark.sql.execution.HiveResult$.hiveResultString(HiveResult.scala:76)
	at org.apache.spark.sql.SQLQueryTestHelper.$anonfun$getNormalizedResult$2(SQLQueryTestHelper.scala:66)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.SQLQueryTestHelper.getNormalizedResult(SQLQueryTestHelper.scala:66)
	at org.apache.spark.sql.SQLQueryTestHelper.getNormalizedResult$(SQLQueryTestHelper.scala:50)
	at org.apache.spark.sql.SQLQueryTestSuite.getNormalizedResult(SQLQueryTestSuite.scala:126)
	at org.apache.spark.sql.SQLQueryTestSuite.$anonfun$runQueries$8(SQLQueryTestSuite.scala:402)
	at org.apache.spark.sql.SQLQueryTestHelper.handleExceptions(SQLQueryTestHelper.scala:81)
	at org.apache.spark.sql.SQLQueryTestHelper.handleExceptions$(SQLQueryTestHelper.scala:79)
	at org.apache.spark.sql.SQLQueryTestSuite.handleExceptions(SQLQueryTestSuite.scala:126)
	at org.apache.spark.sql.SQLQueryTestSuite.$anonfun$runQueries$7(SQLQueryTestSuite.scala:402)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
	at org.apache.spark.sql.SQLQueryTestSuite.runQueries(SQLQueryTestSuite.scala:401)
	at org.apache.spark.sql.SQLQueryTestSuite.$anonfun$runTest$34(SQLQueryTestSuite.scala:349)
	at org.apache.spark.sql.SQLQueryTestSuite.$anonfun$runTest$34$adapted(SQLQueryTestSuite.scala:347)
	at scala.collection.immutable.List.foreach(List.scala:431)
	at org.apache.spark.sql.SQLQueryTestSuite.runTest(SQLQueryTestSuite.scala:347)
	at org.apache.spark.sql.SQLQueryTestSuite.$anonfun$createScalaTestCase$5(SQLQueryTestSuite.scala:254)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
	at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
	at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
	at org.scalatest.Transformer.apply(Transformer.scala:22)
	at org.scalatest.Transformer.apply(Transformer.scala:20)
	at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:190)
	at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:203)
	at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:188)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:200)
	at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:200)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:182)
	at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:64)
	at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
	at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
	at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:64)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:233)
	at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
	at scala.collection.immutable.List.foreach(List.scala:431)
	at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
	at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
	at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:233)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:232)
	at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1563)
	at org.scalatest.Suite.run(Suite.scala:1112)
	at org.scalatest.Suite.run$(Suite.scala:1094)
	at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1563)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:237)
	at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
	at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:237)
	at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:236)
	at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:64)
	at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
	at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
	at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
	at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:64)
	at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:318)
	at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:513)
	at sbt.ForkMain$Run.lambda$runTest$1(ForkMain.java:413)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.reflect.InvocationTargetException
	at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
	at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
	at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
	at java.lang.reflect.Constructor.newInstance(Constructor.java:423)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$6(TreeNode.scala:738)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:83)
	at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:737)
	at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:694)
	at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:99)
	at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:59)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:223)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:508)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$doCanonicalize$1(QueryPlan.scala:506)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
	at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:506)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.doCanonicalize(AdaptiveSparkPlanExec.scala:210)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.doCanonicalize(AdaptiveSparkPlanExec.scala:64)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:850)
	at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:814)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:492)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:491)
	at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized$lzycompute(subquery.scala:72)
	at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized(subquery.scala:71)
	at org.apache.spark.sql.catalyst.expressions.CastBase.preCanonicalized$lzycompute(Cast.scala:319)
	at org.apache.spark.sql.catalyst.expressions.CastBase.preCanonicalized(Cast.scala:318)
	at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized$lzycompute(Expression.scala:261)
	at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized(Expression.scala:261)
	at org.apache.spark.sql.catalyst.expressions.Expression.semanticHash(Expression.scala:278)
	at org.apache.spark.sql.catalyst.expressions.ExpressionEquals.hashCode(EquivalentExpressions.scala:260)
	at scala.runtime.Statics.anyHash(Statics.java:122)
	at scala.collection.mutable.HashTable$HashUtils.elemHashCode(HashTable.scala:416)
	at scala.collection.mutable.HashTable$HashUtils.elemHashCode$(HashTable.scala:416)
	at scala.collection.mutable.HashMap.elemHashCode(HashMap.scala:44)
	at scala.collection.mutable.HashTable.findEntry(HashTable.scala:136)
	at scala.collection.mutable.HashTable.findEntry$(HashTable.scala:135)
	at scala.collection.mutable.HashMap.findEntry(HashMap.scala:44)
	at scala.collection.mutable.HashMap.get(HashMap.scala:74)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.updateExprInMap(EquivalentExpressions.scala:59)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.updateExprTree(EquivalentExpressions.scala:202)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:186)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1(CodeGenerator.scala:1218)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1$adapted(CodeGenerator.scala:1218)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.subexpressionElimination(CodeGenerator.scala:1218)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.generateExpressions(CodeGenerator.scala:1271)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.create(GenerateMutableProjection.scala:64)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.generate(GenerateMutableProjection.scala:49)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:84)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:80)
	at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:47)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:95)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:103)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:118)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:507)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1475)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:510)
	... 3 more
Caused by: java.lang.NullPointerException
	at org.apache.spark.sql.execution.SparkPlan.sparkContext(SparkPlan.scala:62)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.org$apache$spark$sql$execution$python$PythonSQLMetrics$$super$sparkContext(ArrowEvalPythonExec.scala:62)
	at org.apache.spark.sql.execution.python.PythonSQLMetrics.$init$(PythonSQLMetrics.scala:27)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.<init>(ArrowEvalPythonExec.scala:64)
	... 89 more

evalType: Int)
extends EvalPythonExec {
extends EvalPythonExec with PythonSQLMetrics {

private val batchSize = conf.arrowMaxRecordsPerBatch
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
@@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
argOffsets,
schema,
sessionLocalTimeZone,
pythonRunnerConf).compute(batchIter, context.partitionId(), context)
pythonRunnerConf,
pythonMetrics).compute(batchIter, context.partitionId(), context)

columnarBatchIter.flatMap { batch =>
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
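On the NullPointerException in the review comment above: canonicalization calls TreeNode.makeCopy, which re-instantiates ArrowEvalPythonExec through its constructor, and the mixed-in metrics trait's initializer (PythonSQLMetrics.$init$) reaches SparkPlan.sparkContext at construction time; in that code path no SparkSession is active, so the lookup dereferences null. The snippet below only illustrates that failure mode; the body shown for the session lookup is an approximation of SparkPlan's internals, not the actual source.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Illustrative only: why an eager sparkContext lookup can NPE when a plan node is
// rebuilt by makeCopy with no active SparkSession (see the stack trace above).
class PlanNodeSketch {
  // Roughly what SparkPlan does: capture the active session at construction time.
  // When the node is re-instantiated where no session is active, this is null...
  @transient private val session: SparkSession = SparkSession.getActiveSession.orNull

  // ...and an eager metrics initializer that calls this from its constructor
  // dereferences null, producing the NullPointerException seen in the trace.
  protected def sparkContext: SparkContext = session.sparkContext
}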
ArrowPythonRunner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python

import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -32,7 +33,8 @@ class ArrowPythonRunner(
argOffsets: Array[Array[Int]],
protected override val schema: StructType,
protected override val timeZoneId: String,
protected override val workerConf: Map[String, String])
protected override val workerConf: Map[String, String],
val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
with BasicPythonArrowInput
with BasicPythonArrowOutput {
BatchEvalPythonExec.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
* A physical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec {
extends EvalPythonExec with PythonSQLMetrics {

protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
@@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
}.grouped(100).map(x => pickle.dumps(x.toArray))

// Output iterator for results from Python.
val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
val outputIterator =
new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
@@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
pythonMetrics("pythonNumRowsReceived") += 1
if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = fromJava(result)
CoGroupedArrowPythonRunner.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
@@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner(
leftSchema: StructType,
rightSchema: StructType,
timeZoneId: String,
conf: Map[String, String])
conf: Map[String, String],
val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[
(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
with BasicPythonArrowOutput {
@@ -77,10 +79,14 @@ class CoGroupedArrowPythonRunner(
// For each we first send the number of dataframes in each group then send
// first df, then send second df. End of data is marked by sending 0.
while (inputIterator.hasNext) {
val startData = dataOut.size()
dataOut.writeInt(2)
val (nextLeft, nextRight) = inputIterator.next()
writeGroup(nextLeft, leftSchema, dataOut, "left")
writeGroup(nextRight, rightSchema, dataOut, "right")

val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
}
dataOut.writeInt(0)
}
FlatMapCoGroupsInPandasExec.scala
@@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec(
output: Seq[Attribute],
left: SparkPlan,
right: SparkPlan)
extends SparkPlan with BinaryExecNode {
extends SparkPlan with BinaryExecNode with PythonSQLMetrics {

private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec(
}

override protected def doExecute(): RDD[InternalRow] = {

val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup)

@@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec(
StructType.fromAttributes(leftDedup),
StructType.fromAttributes(rightDedup),
sessionLocalTimeZone,
pythonRunnerConf)
pythonRunnerConf,
pythonMetrics)

executePython(data, output, runner)
}
FlatMapGroupsInPandasExec.scala
@@ -50,7 +50,7 @@ case class FlatMapGroupsInPandasExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan)
extends SparkPlan with UnaryExecNode {
extends SparkPlan with UnaryExecNode with PythonSQLMetrics {

private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -89,7 +89,8 @@ case class FlatMapGroupsInPandasExec(
Array(argOffsets),
Member:

e.g., define val localPythonMetrics = pythonMetrics at line 79, and replace pythonMetrics with localPythonMetrics.

Contributor Author:

Thanks @HyukjinKwon for looking into this.
Unfortunately the proposed solution/test of using val localPythonMetrics = pythonMetrics does not appear to work.
Using lazy val for the metrics appears to break many tests with Python. In particular, I can see in that case that when using PySpark and "going through the RDD" as in df_with_udf.rdd.collect(), we get a java.lang.NullPointerException.
I would propose not to skip the failing test in postgreSQL/udf-aggregates_part3.sql, but rather to move it to a Python test: see test_pandas_udf_nested in test_pandas_udf.py.
However if we can understand more clearly where this issue comes from, all the better.

Contributor Author:

I'd like to clarify that the proposed solution for the failing test is to move the test to a Python test in python/pyspark/sql/tests/test_pandas_udf.py, where it runs OK. I would (try to) argue that the issue may come from the way the Python UDF test in udf-aggregates_part3.sql is executed via Scala, and from the fact that that particular test refers to udf twice: udf((select udf(count(*)) which apparently creates an issue there, while it works fine in a Python test.
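On the val localPythonMetrics = pythonMetrics suggestion earlier in this thread: it is the usual Spark pattern of copying an operator field into a local val before building an RDD closure, so the closure serializes only that value rather than the whole plan node (this). The sketch below shows the pattern generically; it is not code from this PR, and the field type is simplified so the example stays self-contained.

import org.apache.spark.rdd.RDD

// Generic sketch of the "capture into a local val" pattern (illustrative only).
class OperatorLikeSketch(val pythonMetrics: Map[String, Long]) extends Serializable {
  def run(input: RDD[Int]): RDD[Int] = {
    // Referencing `pythonMetrics` directly inside the closure would capture `this`
    // (the whole operator); copying it into a local val captures only the map.
    val localPythonMetrics = pythonMetrics
    input.mapPartitions { iter =>
      iter.map(_ + localPythonMetrics.size)
    }
  }
}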

StructType.fromAttributes(dedupAttributes),
sessionLocalTimeZone,
pythonRunnerConf)
pythonRunnerConf,
pythonMetrics)

executePython(data, output, runner)
}}
FlatMapGroupsInPandasWithStateExec.scala
@@ -62,7 +62,8 @@ case class FlatMapGroupsInPandasWithStateExec(
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
child: SparkPlan)
extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase {

// TODO(SPARK-40444): Add the support of initial state.
override protected val initialStateDeserializer: Expression = null
@@ -166,7 +167,8 @@ case class FlatMapGroupsInPandasWithStateExec(
stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
groupingAttributes.toStructType,
outAttributes.toStructType,
stateType)
stateType,
pythonMetrics)

val context = TaskContext.get()

MapInBatchExec.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
* This is somewhat similar with [[FlatMapGroupsInPandasExec]] and
* `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow`
*/
trait MapInBatchExec extends UnaryExecNode {
trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
protected val func: Expression
protected val pythonEvalType: Int

@@ -75,7 +75,8 @@ trait MapInBatchExec extends UnaryExecNode {
argOffsets,
StructType(StructField("struct", outputTypes) :: Nil),
sessionLocalTimeZone,
pythonRunnerConf).compute(batchIter, context.partitionId(), context)
pythonRunnerConf,
pythonMetrics).compute(batchIter, context.partitionId(), context)

val unsafeProj = UnsafeProjection.create(output, output)

PythonArrowInput.scala
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
@@ -41,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>

protected val timeZoneId: String

protected def pythonMetrics: Map[String, SQLMetric]

protected def writeIteratorToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -115,6 +118,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
val arrowWriter = ArrowWriter.create(root)

while (inputIterator.hasNext) {
val startData = dataOut.size()
val nextBatch = inputIterator.next()

while (nextBatch.hasNext) {
Expand All @@ -124,6 +128,8 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
}
}
}
PythonArrowOutput.scala
@@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -37,6 +38,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
*/
private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] =>

protected def pythonMetrics: Map[String, SQLMetric]

protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }

protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
@@ -82,10 +85,15 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
}
try {
if (reader != null && batchLoaded) {
val bytesReadStart = reader.bytesRead()
batchLoaded = reader.loadNextBatch()
if (batchLoaded) {
val batch = new ColumnarBatch(vectors)
val rowCount = root.getRowCount
batch.setNumRows(root.getRowCount)
val bytesReadEnd = reader.bytesRead()
pythonMetrics("pythonNumRowsReceived") += rowCount
pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
deserializeColumnarBatch(batch, schema)
} else {
reader.close(false)