diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 0ecdc4bdef96c..41146e4ef688d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connect.execution +import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal import com.google.protobuf.Message @@ -185,19 +186,34 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends s"${executeHolder.request.getPlan.getOpTypeCase} not supported.") } - if (executeHolder.observations.nonEmpty) { - val observedMetrics = executeHolder.observations.map { case (name, observation) => + val observedMetrics: Map[String, Seq[(Option[String], Any)]] = { + executeHolder.observations.map { case (name, observation) => val values = observation.getOrEmpty.map { case (key, value) => (Some(key), value) }.toSeq name -> values }.toMap + } + val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = { + executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator => + accumulator.synchronized { + val value = accumulator.value.asScala.toSeq + if (value.nonEmpty) { + accumulator.reset() + Some("__python_accumulator__" -> value.map(value => (None, value))) + } else { + None + } + } + }.toMap + } + if (observedMetrics.nonEmpty || accumulatedInPython.nonEmpty) { executeHolder.responseObserver.onNext( SparkConnectPlanExecution .createObservedMetricsResponse( executeHolder.sessionHolder.sessionId, executeHolder.sessionHolder.serverSessionId, - observedMetrics)) + observedMetrics ++ accumulatedInPython)) } lock.synchronized { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 25c78413170ec..5be79a090a38c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -972,8 +972,8 @@ class SparkConnectPlanner( pythonVer = fun.getPythonVer, // Empty broadcast variables broadcastVars = Lists.newArrayList(), - // Null accumulator - accumulator = null) + // Accumulator if available + accumulator = sessionHolder.pythonAccumulator.orNull) } private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = { @@ -1680,8 +1680,8 @@ class SparkConnectPlanner( pythonVer = fun.getPythonVer, // Empty broadcast variables broadcastVars = Lists.newArrayList(), - // Null accumulator - accumulator = null) + // Accumulator if available + accumulator = sessionHolder.pythonAccumulator.orNull) } /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 427b1a50588cc..ef79cdcce8ff5 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -24,11 +24,13 @@ import 
javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.jdk.CollectionConverters._ +import scala.util.Try import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder import org.apache.spark.{SparkException, SparkSQLException} +import org.apache.spark.api.python.PythonFunction.PythonAccumulator import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession @@ -371,6 +373,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private[connect] def listListenerIds(): Seq[String] = { listenerCache.keySet().asScala.toSeq } + + /** + * An accumulator for Python executors. + * + * The accumulated results will be sent to the Python client via observed_metrics message. + */ + private[connect] val pythonAccumulator: Option[PythonAccumulator] = + Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption } object SessionHolder { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 1d0c905164ad4..5aa080b5fb291 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.PythonFunction.PythonAccumulator import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging @@ -83,7 +84,11 @@ private[spark] trait PythonFunction { def pythonExec: String def pythonVer: String def broadcastVars: JList[Broadcast[PythonBroadcast]] - def accumulator: PythonAccumulatorV2 + def accumulator: PythonAccumulator +} + +private[spark] object PythonFunction { + type PythonAccumulator = CollectionAccumulator[Array[Byte]] } /** @@ -96,7 +101,7 @@ private[spark] case class SimplePythonFunction( pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: PythonAccumulatorV2) extends PythonFunction { + accumulator: PythonAccumulator) extends PythonFunction { def this( command: Array[Byte], @@ -105,7 +110,7 @@ private[spark] case class SimplePythonFunction( pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: PythonAccumulatorV2) = { + accumulator: PythonAccumulator) = { this(command.toImmutableArraySeq, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 148f80540d962..17cb0c5a55ddf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -30,6 +30,7 @@ import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal import org.apache.spark._ +import org.apache.spark.api.python.PythonFunction.PythonAccumulator import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python} import org.apache.spark.internal.config.Python._ @@ -146,10 +147,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( }.getOrElse("pyspark.worker") // TODO: support accumulator in multiple UDF - 
protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator + protected val accumulator: PythonAccumulator = funcs.head.funcs.head.accumulator // Python accumulator is always set in production except in tests. See SPARK-27893 - private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator) + private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator) // Expose a ServerSocket to support method calls via socket from Python side. Only relevant for // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] for details. diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala index 3d0687553bb33..ae3614445be6e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala @@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream, File} import java.nio.charset.StandardCharsets import org.apache.spark.{SparkEnv, SparkFiles} +import org.apache.spark.api.python.PythonFunction.PythonAccumulator import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging @@ -186,7 +187,8 @@ private[spark] object PythonWorkerUtils extends Logging { * The updates are sent by `worker_util.send_accumulator_updates`. */ def receiveAccumulatorUpdates( - maybeAccumulator: Option[PythonAccumulatorV2], dataIn: DataInputStream): Unit = { + maybeAccumulator: Option[PythonAccumulator], + dataIn: DataInputStream): Unit = { val numAccumulatorUpdates = dataIn.readInt() (1 to numAccumulatorUpdates).foreach { _ => val update = readBytes(dataIn) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index ad164b1a86363..0500bf38ea8e4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -998,6 +998,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_column", "pyspark.sql.tests.connect.test_parity_readwriter", "pyspark.sql.tests.connect.test_parity_udf", + "pyspark.sql.tests.connect.test_parity_udf_profiler", "pyspark.sql.tests.connect.test_parity_udtf", "pyspark.sql.tests.connect.test_parity_pandas_udf", "pyspark.sql.tests.connect.test_parity_pandas_map", diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index a95bd9debfc21..4f61a9fbd9f73 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -57,6 +57,10 @@ def _deserialize_accumulator( return accum +class SpecialAccumulatorIds: + SQL_UDF_PROFIER = -1 + + class Accumulator(Generic[T]): """ diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index b7ea6a19063b6..b5f1bc4d714d8 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -422,8 +422,9 @@ def stats(self) -> CodeMapDict: """Return the collected memory profiles""" return cast(CodeMapDict, self._accumulator.value) + @staticmethod def _show_results( - self, code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1 + code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1 ) -> None: if stream is None: stream = sys.stdout diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi index cee44c4aa069e..b696eea7293fb 100644 --- a/python/pyspark/sql/_typing.pyi +++ b/python/pyspark/sql/_typing.pyi @@ -19,6 +19,7 @@ from typing import ( Any, Callable, + Dict, List, Optional, Tuple, @@ -29,8 +30,10 @@ from typing_extensions 
import Literal, Protocol import datetime import decimal +import pstats from pyspark._typing import PrimitiveType +from pyspark.profiler import CodeMapDict import pyspark.sql.types from pyspark.sql.column import Column @@ -79,3 +82,5 @@ class UserDefinedFunctionLike(Protocol): def returnType(self) -> pyspark.sql.types.DataType: ... def __call__(self, *args: ColumnOrName) -> Column: ... def asNondeterministic(self) -> UserDefinedFunctionLike: ... + +ProfileResults = Dict[int, Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]] diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 10235fd7d6c47..c1c046e93708c 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -54,11 +54,13 @@ from google.protobuf import text_format, any_pb2 from google.rpc import error_details_pb2 +from pyspark.accumulators import SpecialAccumulatorIds from pyspark.loose_version import LooseVersion from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation from pyspark.sql.connect.client.artifact import ArtifactManager from pyspark.sql.connect.client.logging import logger +from pyspark.sql.connect.profiler import ConnectProfilerCollector from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level @@ -636,6 +638,8 @@ class ClientThreadLocals(threading.local): # be updated on the first response received. self._server_session_id: Optional[str] = None + self._profiler_collector = ConnectProfilerCollector() + def _retrying(self) -> "Retrying": return Retrying(self._retry_policies) @@ -1169,7 +1173,14 @@ def handle_response( if b.observed_metrics: logger.debug("Received observed metric batch.") for observed_metrics in self._build_observed_metrics(b.observed_metrics): - if observed_metrics.name in observations: + if observed_metrics.name == "__python_accumulator__": + from pyspark.worker_util import pickleSer + + for metric in observed_metrics.metrics: + (aid, update) = pickleSer.loads(LiteralExpression._to_value(metric)) + if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER: + self._profiler_collector._update(update) + elif observed_metrics.name in observations: observation_result = observations[observed_metrics.name]._result assert observation_result is not None observation_result.update( diff --git a/python/pyspark/sql/connect/profiler.py b/python/pyspark/sql/connect/profiler.py new file mode 100644 index 0000000000000..b8825cf5678eb --- /dev/null +++ b/python/pyspark/sql/connect/profiler.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
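A minimal standalone sketch (not part of the patch) of the payload shape the handle_response change above decodes: each "__python_accumulator__" metric value is a pickled (accumulator id, value) pair. Only the standard library is used; the result id 42 and the "<pstats.Stats>" stand-in are made up for illustration and are not the real serialized objects.

    import pickle

    SQL_UDF_PROFIER = -1  # mirrors SpecialAccumulatorIds.SQL_UDF_PROFIER in pyspark.accumulators

    # One accumulator update as shipped from the Python worker back through the JVM
    # CollectionAccumulator and the observed-metrics message: a pickled pair of
    # (accumulator id, {udf result id: (perf stats, memory code map)}).
    payload = pickle.dumps((SQL_UDF_PROFIER, {42: ("<pstats.Stats>", None)}))

    aid, update = pickle.loads(payload)
    if aid == SQL_UDF_PROFIER:
        # The Connect client would merge `update` into ConnectProfilerCollector here.
        print(update)
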
+# +from typing import TYPE_CHECKING + +from pyspark.sql.profiler import ProfilerCollector, ProfileResultsParam + +if TYPE_CHECKING: + from pyspark.sql._typing import ProfileResults + + +class ConnectProfilerCollector(ProfilerCollector): + """ + ProfilerCollector for Spark Connect. + """ + + def __init__(self) -> None: + super().__init__() + self._value = ProfileResultsParam.zero(None) + + @property + def _profile_results(self) -> "ProfileResults": + with self._lock: + return self._value if self._value is not None else {} + + def _update(self, update: "ProfileResults") -> None: + with self._lock: + self._value = ProfileResultsParam.addInPlace(self._profile_results, update) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index a27e6fa4b7290..5cbcb4ab5c350 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -62,6 +62,7 @@ CachedRelation, CachedRemoteRelation, ) +from pyspark.sql.connect.profiler import ProfilerCollector from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager @@ -919,6 +920,15 @@ def create_conf(**kwargs: Any) -> SparkConf: def session_id(self) -> str: return self._session_id + @property + def _profiler_collector(self) -> ProfilerCollector: + return self._client._profiler_collector + + def showPerfProfiles(self, id: Optional[int] = None) -> None: + self._profiler_collector.show_perf_profiles(id) + + showPerfProfiles.__doc__ = PySparkSession.showPerfProfiles.__doc__ + SparkSession.__doc__ = PySparkSession.__doc__ diff --git a/python/pyspark/sql/profiler.py b/python/pyspark/sql/profiler.py new file mode 100644 index 0000000000000..565752197238f --- /dev/null +++ b/python/pyspark/sql/profiler.py @@ -0,0 +1,176 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC, abstractmethod +import pstats +from threading import RLock +from typing import Dict, Optional, TYPE_CHECKING + +from pyspark.accumulators import ( + Accumulator, + AccumulatorParam, + SpecialAccumulatorIds, + _accumulatorRegistry, +) +from pyspark.profiler import CodeMapDict, MemoryProfiler, MemUsageParam, PStatsParam + +if TYPE_CHECKING: + from pyspark.sql._typing import ProfileResults + + +class _ProfileResultsParam(AccumulatorParam[Optional["ProfileResults"]]): + """ + AccumulatorParam for profilers. 
+ """ + + @staticmethod + def zero(value: Optional["ProfileResults"]) -> Optional["ProfileResults"]: + return value + + @staticmethod + def addInPlace( + value1: Optional["ProfileResults"], value2: Optional["ProfileResults"] + ) -> Optional["ProfileResults"]: + if value1 is None or len(value1) == 0: + value1 = {} + if value2 is None or len(value2) == 0: + value2 = {} + + value = value1.copy() + for key, (perf, mem, *_) in value2.items(): + if key in value1: + orig_perf, orig_mem, *_ = value1[key] + else: + orig_perf, orig_mem = (PStatsParam.zero(None), MemUsageParam.zero(None)) + value[key] = ( + PStatsParam.addInPlace(orig_perf, perf), + MemUsageParam.addInPlace(orig_mem, mem), + ) + return value + + +ProfileResultsParam = _ProfileResultsParam() + + +class ProfilerCollector(ABC): + """ + A base class of profiler collectors for session based profilers. + + This supports cProfiler and memory-profiler enabled by setting a SQL config + `spark.sql.pyspark.udf.profiler` to "perf" or "memory". + """ + + def __init__(self) -> None: + self._lock = RLock() + + def show_perf_profiles(self, id: Optional[int] = None) -> None: + """ + Show the perf profile results. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + id : int, optional + A UDF ID to be shown. If not specified, all the results will be shown. + """ + with self._lock: + stats = self._perf_profile_results + + def show(id: int) -> None: + s = stats.get(id) + if s is not None: + print("=" * 60) + print(f"Profile of UDF") + print("=" * 60) + s.sort_stats("time", "cumulative").print_stats() + + if id is not None: + show(id) + else: + for id in sorted(stats.keys()): + show(id) + + @property + def _perf_profile_results(self) -> Dict[int, pstats.Stats]: + with self._lock: + return { + result_id: perf + for result_id, (perf, _, *_) in self._profile_results.items() + if perf is not None + } + + def show_memory_profiles(self, id: Optional[int] = None) -> None: + """ + Show the memory profile results. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + id : int, optional + A UDF ID to be shown. If not specified, all the results will be shown. + """ + with self._lock: + code_map = self._memory_profile_results + + def show(id: int) -> None: + cm = code_map.get(id) + if cm is not None: + print("=" * 60) + print(f"Profile of UDF") + print("=" * 60) + MemoryProfiler._show_results(cm) + + if id is not None: + show(id) + else: + for id in sorted(code_map.keys()): + show(id) + + @property + def _memory_profile_results(self) -> Dict[int, CodeMapDict]: + with self._lock: + return { + result_id: mem + for result_id, (_, mem, *_) in self._profile_results.items() + if mem is not None + } + + @property + @abstractmethod + def _profile_results(self) -> "ProfileResults": + """ + Get the profile results. + """ + ... 
+ + +class AccumulatorProfilerCollector(ProfilerCollector): + def __init__(self) -> None: + super().__init__() + if SpecialAccumulatorIds.SQL_UDF_PROFIER in _accumulatorRegistry: + self._accumulator = _accumulatorRegistry[SpecialAccumulatorIds.SQL_UDF_PROFIER] + else: + self._accumulator = Accumulator( + SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam + ) + + @property + def _profile_results(self) -> "ProfileResults": + with self._lock: + value = self._accumulator.value + return value if value is not None else {} diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 10b56d006dcd5..fef834b9f0a0a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -47,6 +47,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import lit from pyspark.sql.pandas.conversion import SparkConversionMixin +from pyspark.sql.profiler import AccumulatorProfilerCollector, ProfilerCollector from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.sql_formatter import SQLStringFormatter from pyspark.sql.streaming import DataStreamReader @@ -623,6 +624,8 @@ def __init__( self._jvm.SparkSession.setDefaultSession(self._jsparkSession) self._jvm.SparkSession.setActiveSession(self._jsparkSession) + self._profiler_collector = AccumulatorProfilerCollector() + def _repr_html_(self) -> str: return """
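A short usage sketch of the session-based profiler wired up in this hunk and the one below; it assumes an active SparkSession bound to `spark` and mirrors what the new tests exercise rather than adding any API.

    from pyspark.sql.functions import col, udf

    @udf("long")
    def add1(x):
        return x + 1

    # Enable the profiler introduced by this patch; "memory" is also accepted by the
    # conf, but the worker-side memory profiler is still TODO (SPARK-46687).
    spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")

    spark.range(10).select(add1(col("id"))).collect()

    # Print the accumulated cProfile output, for all UDFs or for one result id.
    spark.showPerfProfiles()
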
@@ -2110,6 +2113,11 @@ def clearTags(self) -> None: message_parameters={"feature": "SparkSession.clearTags"}, ) + def showPerfProfiles(self, id: Optional[int] = None) -> None: + self._profiler_collector.show_perf_profiles(id) + + showPerfProfiles.__doc__ = ProfilerCollector.show_perf_profiles.__doc__ + def _test() -> None: import os diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py new file mode 100644 index 0000000000000..463d924410941 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import os +import unittest + +from pyspark.sql.tests.test_udf_profiler import UDFProfiler2TestsMixin, _do_computation +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase): + def setUp(self) -> None: + super().setUp() + self.spark._profiler_collector._value = None + + def test_perf_profiler_udf_multiple_actions(self): + def action(df): + df.collect() + df.show() + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + _do_computation(self.spark, action=action) + + self.assertEqual(6, len(self.profile_results), str(list(self.profile_results))) + + for id in self.profile_results: + with self.trap_stdout() as io: + self.spark.showPerfProfiles(id) + + self.assertIn(f"Profile of UDF", io.getvalue()) + self.assertRegex( + io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}" + ) + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_udf_profiler import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py index 776d5da88bb27..d365523e456c9 100644 --- a/python/pyspark/sql/tests/test_udf_profiler.py +++ b/python/pyspark/sql/tests/test_udf_profiler.py @@ -15,18 +15,40 @@ # limitations under the License. 
# +from contextlib import contextmanager +import inspect import tempfile import unittest import os import sys import warnings from io import StringIO -from typing import Iterator +from typing import Iterator, cast from pyspark import SparkConf from pyspark.sql import SparkSession -from pyspark.sql.functions import udf, pandas_udf +from pyspark.sql.functions import col, pandas_udf, udf from pyspark.profiler import UDFBasicProfiler +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + + +def _do_computation(spark, *, action=lambda df: df.collect(), use_arrow=False): + @udf("long", useArrow=use_arrow) + def add1(x): + return x + 1 + + @udf("long", useArrow=use_arrow) + def add2(x): + return x + 2 + + df = spark.range(10).select(add1("id"), add2("id"), add1("id"), add2(col("id") + 1)) + action(df) class UDFProfilerTests(unittest.TestCase): @@ -47,10 +69,10 @@ def tearDown(self): sys.path = self._old_sys_path def test_udf_profiler(self): - self.do_computation() + _do_computation(self.spark) profilers = self.sc.profiler_collector.profilers - self.assertEqual(3, len(profilers)) + self.assertEqual(4, len(profilers)) old_stdout = sys.stdout try: @@ -62,7 +84,7 @@ def test_udf_profiler(self): d = tempfile.gettempdir() self.sc.dump_profiles(d) - for i, udf_name in enumerate(["add1", "add2", "add1"]): + for i, udf_name in enumerate(["add1", "add2", "add1", "add2"]): id, profiler, _ = profilers[i] with self.subTest(id=id, udf_name=udf_name): stats = profiler.stats() @@ -81,28 +103,16 @@ def show(self, id): self.sc.profiler_collector.udf_profiler_cls = TestCustomProfiler - self.do_computation() + _do_computation(self.spark) profilers = self.sc.profiler_collector.profilers - self.assertEqual(3, len(profilers)) + self.assertEqual(4, len(profilers)) _, profiler, _ = profilers[0] self.assertTrue(isinstance(profiler, TestCustomProfiler)) self.sc.show_profiles() self.assertEqual("Custom formatting", profiler.result) - def do_computation(self): - @udf - def add1(x): - return x + 1 - - @udf - def add2(x): - return x + 2 - - df = self.spark.range(10) - df.select(add1("id"), add2("id"), add1("id")).collect() - # Unsupported def exec_pandas_udf_iter_to_iter(self): import pandas as pd @@ -145,6 +155,190 @@ def test_unsupported(self): ) +class UDFProfiler2TestsMixin: + @contextmanager + def trap_stdout(self): + old_stdout = sys.stdout + sys.stdout = io = StringIO() + try: + yield io + finally: + sys.stdout = old_stdout + + @property + def profile_results(self): + return self.spark._profiler_collector._perf_profile_results + + def test_perf_profiler_udf(self): + _do_computation(self.spark) + + # Without the conf enabled, no profile results are collected. 
+ self.assertEqual(0, len(self.profile_results), str(list(self.profile_results))) + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + _do_computation(self.spark) + + self.assertEqual(3, len(self.profile_results), str(list(self.profile_results))) + + with self.trap_stdout() as io_all: + self.spark.showPerfProfiles() + + for id in self.profile_results: + self.assertIn(f"Profile of UDF", io_all.getvalue()) + + with self.trap_stdout() as io: + self.spark.showPerfProfiles(id) + + self.assertIn(f"Profile of UDF", io.getvalue()) + self.assertRegex( + io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}" + ) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), + ) + def test_perf_profiler_udf_with_arrow(self): + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + _do_computation(self.spark, use_arrow=True) + + self.assertEqual(3, len(self.profile_results), str(list(self.profile_results))) + + for id in self.profile_results: + with self.trap_stdout() as io: + self.spark.showPerfProfiles(id) + + self.assertIn(f"Profile of UDF", io.getvalue()) + self.assertRegex( + io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}" + ) + + def test_perf_profiler_udf_multiple_actions(self): + def action(df): + df.collect() + df.show() + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + _do_computation(self.spark, action=action) + + self.assertEqual(3, len(self.profile_results), str(list(self.profile_results))) + + for id in self.profile_results: + with self.trap_stdout() as io: + self.spark.showPerfProfiles(id) + + self.assertIn(f"Profile of UDF", io.getvalue()) + self.assertRegex( + io.getvalue(), f"20.*{os.path.basename(inspect.getfile(_do_computation))}" + ) + + def test_perf_profiler_udf_registered(self): + @udf("long") + def add1(x): + return x + 1 + + self.spark.udf.register("add1", add1) + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + self.spark.sql("SELECT id, add1(id) add1 FROM range(10)").collect() + + self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys())) + + for id in self.profile_results: + with self.trap_stdout() as io: + self.spark.showPerfProfiles(id) + + self.assertIn(f"Profile of UDF", io.getvalue()) + self.assertRegex( + io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}" + ) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), + ) + def test_perf_profiler_pandas_udf(self): + @pandas_udf("long") + def add1(x): + return x + 1 + + @pandas_udf("long") + def add2(x): + return x + 2 + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + df = self.spark.range(10, numPartitions=2).select( + add1("id"), add2("id"), add1("id"), add2(col("id") + 1) + ) + df.collect() + + self.assertEqual(3, len(self.profile_results), str(self.profile_results.keys())) + + for id in self.profile_results: + with self.trap_stdout() as io: + self.spark.showPerfProfiles(id) + + self.assertIn(f"Profile of UDF", io.getvalue()) + self.assertRegex( + io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" + ) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), + ) + def test_perf_profiler_pandas_udf_iterator_not_supported(self): + import pandas as pd + + @pandas_udf("long") + def add1(x): + return x + 1 
+ + @pandas_udf("long") + def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: + for s in iter: + yield s + 2 + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + df = self.spark.range(10, numPartitions=2).select( + add1("id"), add2("id"), add1("id"), add2(col("id") + 1) + ) + df.collect() + + self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys())) + + for id in self.profile_results: + with self.trap_stdout() as io: + self.spark.showPerfProfiles(id) + + self.assertIn(f"Profile of UDF", io.getvalue()) + self.assertRegex( + io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" + ) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), + ) + def test_perf_profiler_map_in_pandas_not_supported(self): + df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) + + def filter_func(iterator): + for pdf in iterator: + yield pdf[pdf.id == 1] + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + df.mapInPandas(filter_func, df.schema).show() + + self.assertEqual(0, len(self.profile_results), str(self.profile_results.keys())) + + +class UDFProfiler2Tests(UDFProfiler2TestsMixin, ReusedSQLTestCase): + def setUp(self) -> None: + super().setUp() + self.spark._profiler_collector._accumulator._value = None + + if __name__ == "__main__": from pyspark.sql.tests.test_udf_profiler import * # noqa: F401 diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 158e3ae62bb75..3e3592a8ffa2a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,7 +27,11 @@ from typing import Any, Callable, Iterable, Iterator, Optional import faulthandler -from pyspark.accumulators import _accumulatorRegistry +from pyspark.accumulators import ( + SpecialAccumulatorIds, + _accumulatorRegistry, + _deserialize_accumulator, +) from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.resource import ResourceInformation @@ -688,7 +692,40 @@ def func(*args): return f, args_offsets -def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): +def _supports_profiler(eval_type: int) -> bool: + return eval_type not in ( + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ) + + +def wrap_perf_profiler(f, result_id): + import cProfile + import pstats + + from pyspark.sql.profiler import ProfileResultsParam + + accumulator = _deserialize_accumulator( + SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam + ) + + def profiling_func(*args, **kwargs): + pr = cProfile.Profile() + ret = pr.runcall(f, *args, **kwargs) + st = pstats.Stats(pr) + st.stream = None # make it picklable + st.strip_dirs() + + accumulator.add({result_id: (st, None)}) + + return ret + + return profiling_func + + +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profiler): num_arg = read_int(infile) if eval_type in ( @@ -721,15 +758,31 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): else: chained_func = chain(chained_func, f) + if profiler == "perf": + result_id = read_long(infile) + + if _supports_profiler(eval_type): + profiling_func = wrap_perf_profiler(chained_func, result_id) + else: + profiling_func = chained_func + + elif profiler == "memory": + # TODO(SPARK-46687): Implement memory 
profiler + result_id = read_long(infile) + profiling_func = chained_func + + else: + profiling_func = chained_func + if eval_type in ( PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, ): - func = chained_func + func = profiling_func else: # make sure StopIteration's raised in the user code are not ignored # when they are processed in a for loop, raise them as RuntimeError's instead - func = fail_on_stopiteration(chained_func) + func = fail_on_stopiteration(profiling_func) # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: @@ -1403,6 +1456,12 @@ def read_udfs(pickleSer, infile, eval_type): else: ser = BatchedSerializer(CPickleSerializer(), 100) + is_profiling = read_bool(infile) + if is_profiling: + profiler = utf8_deserializer.loads(infile) + else: + profiler = None + num_udfs = read_int(infile) is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF @@ -1417,7 +1476,9 @@ def read_udfs(pickleSer, infile, eval_type): if is_map_arrow_iter: assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + ) def func(_, iterator): num_input_rows = 0 @@ -1507,7 +1568,9 @@ def extract_key_value_indexes(grouped_arg_offsets): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + ) parsed_offsets = extract_key_value_indexes(arg_offsets) # Create function like this: @@ -1526,7 +1589,9 @@ def mapper(a): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + ) parsed_offsets = extract_key_value_indexes(arg_offsets) def batch_from_offset(batch, offsets): @@ -1550,7 +1615,9 @@ def mapper(a): # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + ) parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(a): @@ -1584,7 +1651,9 @@ def mapper(a): # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -1601,7 +1670,9 @@ def mapper(a): # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. 
assert num_udfs == 1 - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -1624,7 +1695,11 @@ def mapper(a): else: udfs = [] for i in range(num_udfs): - udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)) + udfs.append( + read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler + ) + ) def mapper(a): result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 743a2e20c8858..eb5233bfb1231 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2927,6 +2927,17 @@ object SQLConf { // show full stacktrace in tests but hide in production by default. .createWithDefault(Utils.isTesting) + val PYTHON_UDF_PROFILER = + buildConf("spark.sql.pyspark.udf.profiler") + .doc("Configure the Python/Pandas UDF profiler by enabling or disabling it " + + "with the option to choose between \"perf\" and \"memory\" types, " + + "or unsetting the config disables the profiler. This is disabled by default.") + .version("4.0.0") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set("perf", "memory")) + .createOptional + val PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED = buildConf("spark.sql.execution.pyspark.udf.faulthandler.enabled") .doc( @@ -5296,6 +5307,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED) + def pythonUDFProfiler: Option[String] = getConf(PYTHON_UDF_PROFILER) + def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED) def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala index f11a63429d78b..0e5f359ee76f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala @@ -128,7 +128,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) new MapInBatchEvaluatorFactory( toAttributes(outputSchema), - Seq(ChainedPythonFunctions(Seq(pythonUDF.func))), + Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)), inputSchema, conf.arrowMaxRecordsPerBatch, pythonEvalType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 7e349b665f352..8763731774478 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -85,15 +85,15 @@ case class AggregateInPandasExec( } private def 
collectFunctions( - udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = { + udf: PythonFuncExpression): ((ChainedPythonFunctions, Long), Seq[Expression]) = { udf.children match { case Seq(u: PythonFuncExpression) => - val (chained, children) = collectFunctions(u) - (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + val ((chained, _), children) = collectFunctions(u) + ((ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), udf.resultId.id), children) case children => // There should not be any other UDFs, or the children can't be evaluated directly. assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression]))) - (ChainedPythonFunctions(Seq(udf.func)), udf.children) + ((ChainedPythonFunctions(Seq(udf.func)), udf.resultId.id), udf.children) } } @@ -180,7 +180,9 @@ case class AggregateInPandasExec( largeVarTypes, pythonRunnerConf, pythonMetrics, - jobArtifactUUID).compute(projectedRowIter, context.partitionId(), context) + jobArtifactUUID, + None) // TODO(SPARK-46688): Support profiling on AggregateInPandasExec + .compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 8795374b2a723..8eeb919d0bafd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -51,7 +51,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} * and output along with data, which requires different struct on Arrow RecordBatch. 
*/ class ApplyInPandasWithStatePythonRunner( - funcs: Seq[ChainedPythonFunctions], + funcs: Seq[(ChainedPythonFunctions, Long)], evalType: Int, argOffsets: Array[Array[Int]], inputSchema: StructType, @@ -63,13 +63,13 @@ class ApplyInPandasWithStatePythonRunner( stateValueSchema: StructType, override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) - extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets, jobArtifactUUID) + extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID) with PythonArrowInput[InType] with PythonArrowOutput[OutType] { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( - funcs.head.funcs.head.pythonExec) + funcs.head._1.funcs.head.pythonExec) override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled @@ -108,7 +108,7 @@ class ApplyInPandasWithStatePythonRunner( private val stateRowDeserializer = stateEncoder.createDeserializer() override protected def writeUDF(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index a6937b7bf89ca..da4c5bff34459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -77,7 +77,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] conf.arrowUseLargeVarTypes, ArrowPythonRunner.getPythonRunnerConfMap(conf), pythonMetrics, - jobArtifactUUID) + jobArtifactUUID, + conf.pythonUDFProfiler) } override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -94,11 +95,12 @@ class ArrowEvalPythonEvaluatorFactory( largeVarTypes: Boolean, pythonRunnerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) + jobArtifactUUID: Option[String], + profiler: Option[String]) extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { override def evaluate( - funcs: Seq[ChainedPythonFunctions], + funcs: Seq[(ChainedPythonFunctions, Long)], argMetas: Array[Array[ArgumentMetadata]], iter: Iterator[InternalRow], schema: StructType, @@ -118,7 +120,8 @@ class ArrowEvalPythonEvaluatorFactory( largeVarTypes, pythonRunnerConf, pythonMetrics, - jobArtifactUUID).compute(batchIter, context.partitionId(), context) + jobArtifactUUID, + profiler).compute(batchIter, context.partitionId(), context) columnarBatchIter.flatMap { batch => val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 33933b64bbaf7..a555d660ea1ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch abstract class BaseArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], + funcs: Seq[(ChainedPythonFunctions, Long)], evalType: Int, argOffsets: Array[Array[Int]], 
_schema: StructType, @@ -38,13 +38,13 @@ abstract class BaseArrowPythonRunner( override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, evalType, argOffsets, jobArtifactUUID) + funcs.map(_._1), evalType, argOffsets, jobArtifactUUID) with BasicPythonArrowInput with BasicPythonArrowOutput { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( - funcs.head.funcs.head.pythonExec) + funcs.head._1.funcs.head.pythonExec) override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled @@ -67,7 +67,7 @@ abstract class BaseArrowPythonRunner( * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. */ class ArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], + funcs: Seq[(ChainedPythonFunctions, Long)], evalType: Int, argOffsets: Array[Array[Int]], _schema: StructType, @@ -75,13 +75,14 @@ class ArrowPythonRunner( largeVarTypes: Boolean, workerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) + jobArtifactUUID: Option[String], + profiler: Option[String]) extends BaseArrowPythonRunner( funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf, pythonMetrics, jobArtifactUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler) } /** @@ -89,7 +90,7 @@ class ArrowPythonRunner( * via Arrow stream. */ class ArrowPythonWithNamedArgumentRunner( - funcs: Seq[ChainedPythonFunctions], + funcs: Seq[(ChainedPythonFunctions, Long)], evalType: Int, argMetas: Array[Array[ArgumentMetadata]], _schema: StructType, @@ -97,13 +98,14 @@ class ArrowPythonWithNamedArgumentRunner( largeVarTypes: Boolean, workerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) + jobArtifactUUID: Option[String], + profiler: Option[String]) extends BaseArrowPythonRunner( funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf, pythonMetrics, jobArtifactUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = - PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas) + PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas, profiler) } object ArrowPythonRunner { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 04d71c6c0153c..e6958392cad48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -44,7 +44,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] udfs, output, pythonMetrics, - jobArtifactUUID) + jobArtifactUUID, + conf.pythonUDFProfiler) } override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec = @@ -56,11 +57,12 @@ class BatchEvalPythonEvaluatorFactory( udfs: Seq[PythonUDF], output: Seq[Attribute], pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { + jobArtifactUUID: Option[String], + profiler: Option[String]) + extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { override def evaluate( - funcs: 
Seq[ChainedPythonFunctions], + funcs: Seq[(ChainedPythonFunctions, Long)], argMetas: Array[Array[ArgumentMetadata]], iter: Iterator[InternalRow], schema: StructType, @@ -73,7 +75,7 @@ class BatchEvalPythonEvaluatorFactory( // Output iterator for results from Python. val outputIterator = new PythonUDFWithNamedArgumentsRunner( - funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, jobArtifactUUID) + funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, jobArtifactUUID, profiler) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index e6b19910296e3..9eebd4ea7e79c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -98,7 +98,7 @@ class PythonUDTFRunner( pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonUDFRunner( - Seq(ChainedPythonFunctions(Seq(udtf.func))), + Seq((ChainedPythonFunctions(Seq(udtf.func)), udtf.resultId.id)), PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), pythonMetrics, jobArtifactUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index 7e1c8c2ffc13b..5670cad67e7b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils * groups them in Python, and receive it back in JVM as batches of single DataFrame. 
  */
 class CoGroupedArrowPythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     leftSchema: StructType,
@@ -47,15 +47,16 @@ class CoGroupedArrowPythonRunner(
     timeZoneId: String,
     conf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
-    funcs, evalType, argOffsets, jobArtifactUUID)
+    funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
   with BasicPythonArrowOutput {
 
   override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-    funcs.head.funcs.head.pythonExec)
+    funcs.head._1.funcs.head.pythonExec)
 
   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
@@ -79,7 +80,7 @@ class CoGroupedArrowPythonRunner(
         PythonRDD.writeUTF(v, dataOut)
       }
 
-      PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+      PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
     }
 
   override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
index d5142f58eab47..34f9be0aa633d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
@@ -36,7 +36,7 @@ abstract class EvalPythonEvaluatorFactory(
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   protected def evaluate(
-      funcs: Seq[ChainedPythonFunctions],
+      funcs: Seq[(ChainedPythonFunctions, Long)],
       argMetas: Array[Array[ArgumentMetadata]],
       iter: Iterator[InternalRow],
       schema: StructType,
@@ -47,15 +47,16 @@ abstract class EvalPythonEvaluatorFactory(
   private class EvalPythonPartitionEvaluator
       extends PartitionEvaluator[InternalRow, InternalRow] {
-    private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    private def collectFunctions(
+        udf: PythonUDF): ((ChainedPythonFunctions, Long), Seq[Expression]) = {
       udf.children match {
         case Seq(u: PythonUDF) =>
-          val (chained, children) = collectFunctions(u)
-          (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+          val ((chained, _), children) = collectFunctions(u)
+          ((ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), udf.resultId.id), children)
         case children =>
           // There should not be any other UDFs, or the children can't be evaluated directly.
           assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF])))
-          (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+          ((ChainedPythonFunctions(Seq(udf.func)), udf.resultId.id), udf.children)
       }
     }
 
     override def eval(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index 97aa1495670fd..bc6f9859ec28a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -43,8 +43,10 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+  private val pythonUDF = func.asInstanceOf[PythonUDF]
+  private val pandasFunction = pythonUDF.func
+  private val chainedFunc =
+    Seq((ChainedPythonFunctions(Seq(pandasFunction)), pythonUDF.resultId.id))
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
@@ -84,7 +86,8 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
       sessionLocalTimeZone,
       pythonRunnerConf,
       pythonMetrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      None) // TODO(SPARK-46690): Support profiling on FlatMapCoGroupsInBatchExec
 
     executePython(data, output, runner)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
index facf7bc49c5ab..580ef46e842d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
@@ -43,8 +43,10 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pythonFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private val pythonUDF = func.asInstanceOf[PythonUDF]
+  private val pythonFunction = pythonUDF.func
+  private val chainedFunc =
+    Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
@@ -89,7 +91,8 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
       largeVarTypes,
       pythonRunnerConf,
       pythonMetrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      None) // TODO(SPARK-46689): Support profiling on FlatMapGroupsInBatchExec
     executePython(data, output, runner)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 105c5ca6493e7..850ee016e3631 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -82,8 +82,10 @@ case class FlatMapGroupsInPandasWithStateExec(
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private val pythonUDF = functionExpr.asInstanceOf[PythonUDF]
+  private val pythonFunction = pythonUDF.func
+  private val chainedFunc =
+    Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
 
   private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
     groupingAttributes ++ child.output, groupingAttributes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 29dc6e0aa541f..8d65fe558937f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 class MapInBatchEvaluatorFactory(
     output: Seq[Attribute],
-    chainedFunc: Seq[ChainedPythonFunctions],
+    chainedFunc: Seq[(ChainedPythonFunctions, Long)],
     outputTypes: StructType,
     batchSize: Int,
     pythonEvalType: Int,
@@ -38,7 +38,7 @@ class MapInBatchEvaluatorFactory(
     pythonRunnerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
-  extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
+    extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
     new MapInBatchEvaluator
@@ -70,7 +70,8 @@ class MapInBatchEvaluatorFactory(
         largeVarTypes,
         pythonRunnerConf,
         pythonMetrics,
-        jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+        jobArtifactUUID,
+        None).compute(batchIter, context.partitionId(), context)
 
       val unsafeProj = UnsafeProjection.create(output, output)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 8db389f02667a..346a3a2ca354e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -46,8 +46,9 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
 
   override protected def doExecute(): RDD[InternalRow] = {
     val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-    val pythonFunction = func.asInstanceOf[PythonUDF].func
-    val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+    val pythonUDF = func.asInstanceOf[PythonUDF]
+    val pythonFunction = pythonUDF.func
+    val chainedFunc = Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
     val evaluatorFactory = new MapInBatchEvaluatorFactory(
       output,
       chainedFunc,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 167a96ed41c72..bbe9fbfc748db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -30,17 +30,17 @@ import org.apache.spark.sql.internal.SQLConf
  * A helper class to run Python UDFs in Spark.
  */
 abstract class BasePythonUDFRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Array[Byte], Array[Byte]](
-    funcs, evalType, argOffsets, jobArtifactUUID) {
+    funcs.map(_._1), evalType, argOffsets, jobArtifactUUID) {
 
   override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-    funcs.head.funcs.head.pythonExec)
+    funcs.head._1.funcs.head.pythonExec)
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
@@ -112,29 +112,31 @@ abstract class BasePythonUDFRunner(
 }
 
 class PythonUDFRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
   }
 }
 
 class PythonUDFWithNamedArgumentsRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argMetas: Array[Array[ArgumentMetadata]],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BasePythonUDFRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas, profiler)
   }
 }
 
@@ -142,10 +144,17 @@ object PythonUDFRunner {
 
   def writeUDFs(
       dataOut: DataOutputStream,
-      funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]]): Unit = {
+      funcs: Seq[(ChainedPythonFunctions, Long)],
+      argOffsets: Array[Array[Int]],
+      profiler: Option[String]): Unit = {
+    profiler match {
+      case Some(p) =>
+        dataOut.writeBoolean(true)
+        PythonWorkerUtils.writeUTF(p, dataOut)
+      case _ => dataOut.writeBoolean(false)
+    }
     dataOut.writeInt(funcs.length)
-    funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+    funcs.zip(argOffsets).foreach { case ((chained, resultId), offsets) =>
       dataOut.writeInt(offsets.length)
       offsets.foreach { offset =>
         dataOut.writeInt(offset)
@@ -154,15 +163,25 @@ object PythonUDFRunner {
       chained.funcs.foreach { f =>
         PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
+      if (profiler.isDefined) {
+        dataOut.writeLong(resultId)
+      }
     }
   }
 
   def writeUDFs(
       dataOut: DataOutputStream,
-      funcs: Seq[ChainedPythonFunctions],
-      argMetas: Array[Array[ArgumentMetadata]]): Unit = {
+      funcs: Seq[(ChainedPythonFunctions, Long)],
+      argMetas: Array[Array[ArgumentMetadata]],
+      profiler: Option[String]): Unit = {
+    profiler match {
+      case Some(p) =>
+        dataOut.writeBoolean(true)
+        PythonWorkerUtils.writeUTF(p, dataOut)
+      case _ => dataOut.writeBoolean(false)
+    }
     dataOut.writeInt(funcs.length)
-    funcs.zip(argMetas).foreach { case (chained, metas) =>
+    funcs.zip(argMetas).foreach { case ((chained, resultId), metas) =>
      dataOut.writeInt(metas.length)
       metas.foreach {
         case ArgumentMetadata(offset, name) =>
@@ -179,6 +198,9 @@ object PythonUDFRunner {
       chained.funcs.foreach { f =>
         PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
+      if (profiler.isDefined) {
+        dataOut.writeLong(resultId)
+      }
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index 12d484b12dacf..e7fc9c7391af4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -43,7 +43,8 @@ class WindowInPandasEvaluatorFactory(
     val orderSpec: Seq[SortOrder],
     val childOutput: Seq[Attribute],
     val spillSize: SQLMetric,
-    pythonMetrics: Map[String, SQLMetric])
+    pythonMetrics: Map[String, SQLMetric],
+    profiler: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase {
 
   /**
@@ -69,15 +70,15 @@ class WindowInPandasEvaluatorFactory(
   private val windowBoundTypeConf = "pandas_window_bound_types"
 
   private def collectFunctions(
-      udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = {
+      udf: PythonFuncExpression): ((ChainedPythonFunctions, Long), Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonFuncExpression) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+        val ((chained, _), children) = collectFunctions(u)
+        ((ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), udf.resultId.id), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression])))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+        ((ChainedPythonFunctions(Seq(udf.func)), udf.resultId.id), udf.children)
     }
   }
 
@@ -368,7 +369,8 @@ class WindowInPandasEvaluatorFactory(
         largeVarTypes,
         pythonRunnerConf,
         pythonMetrics,
-        jobArtifactUUID).compute(pythonInput, context.partitionId(), context)
+        jobArtifactUUID,
+        profiler).compute(pythonInput, context.partitionId(), context)
 
       val joined = new JoinedRow
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index ee0044162b9a1..c0a38eadbe642 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -87,7 +87,8 @@ case class WindowInPandasExec(
       orderSpec,
       child.output,
       longMetric("spillSize"),
-      pythonMetrics)
+      pythonMetrics,
+      None) // TODO(SPARK-46691): Support profiling on WindowInPandasExec
 
     // Start processing.
     if (conf.usePartitionEvaluator) {
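
Not part of the patch: a minimal sketch of what the rewritten collectFunctions helpers above (in EvalPythonEvaluatorFactory and WindowInPandasEvaluatorFactory) produce for chained UDFs. Assuming two hypothetical PythonUDF expressions where udf2.children == Seq(udf1), the chain is flattened into a single ChainedPythonFunctions while only the outermost expression's result id is kept, so profiling output for the whole chain is keyed by udf2.

  // Sketch only; udf1 and udf2 are hypothetical PythonUDF expressions, not defined in this patch.
  val expected: (ChainedPythonFunctions, Long) =
    (ChainedPythonFunctions(Seq(udf1.func, udf2.func)), udf2.resultId.id)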
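Not part of the patch: a minimal sketch of how a call site is expected to wire the new Seq[(ChainedPythonFunctions, Long)] shape and the profiler option into the updated PythonUDFRunner constructor, mirroring the call sites changed above. It assumes pythonUDF: PythonUDF, argOffsets, pythonMetrics and jobArtifactUUID are in scope as at the existing call sites; "perf" is a hypothetical profiler name, and passing None preserves the old no-profiling behaviour.

  val chainedFunc: Seq[(ChainedPythonFunctions, Long)] =
    Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id))

  val runner = new PythonUDFRunner(
    chainedFunc,
    PythonEvalType.SQL_BATCHED_UDF,
    argOffsets,
    pythonMetrics,
    jobArtifactUUID,
    profiler = Some("perf"))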
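Not part of the patch: both writeUDFs overloads change the stream framing in the same way, so the worker side must read a boolean profiler flag (followed by the profiler name when the flag is true) before the UDF count, and one trailing result id per UDF only when profiling is enabled. The self-contained Scala sketch below writes that framing into an in-memory stream using hypothetical values (one UDF, argument offsets 0 and 1, result id 42L); DataOutputStream.writeUTF stands in for PythonWorkerUtils.writeUTF, and the serialized chained-function payload, which this patch does not change, is elided.

  import java.io.{ByteArrayOutputStream, DataOutputStream}

  object WriteUDFsFramingSketch {
    def main(args: Array[String]): Unit = {
      val bytes = new ByteArrayOutputStream()
      val dataOut = new DataOutputStream(bytes)
      dataOut.writeBoolean(true)  // profiler defined; writeUDFs writes false when profiler is None
      dataOut.writeUTF("perf")    // stand-in for PythonWorkerUtils.writeUTF(profiler name, dataOut)
      dataOut.writeInt(1)         // number of (chained) UDFs
      dataOut.writeInt(2)         // number of argument offsets for this UDF
      dataOut.writeInt(0)         // offset 0
      dataOut.writeInt(1)         // offset 1
      // ... the serialized chained Python functions would follow here, unchanged by this patch ...
      dataOut.writeLong(42L)      // the UDF's result id, written only when profiling is enabled
      dataOut.flush()
      println(s"framing sketch wrote ${bytes.size()} bytes")
    }
  }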