diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 977bff690bacc..290fe4d0398e8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -549,13 +549,15 @@ class SparkConnectPlanner(
           pythonUdf,
           DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
           baseRel,
-          isBarrier)
+          isBarrier,
+          None)
       case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
         logical.MapInArrow(
           pythonUdf,
           DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
           baseRel,
-          isBarrier)
+          isBarrier,
+          None)
       case _ =>
         throw InvalidPlanInput(
           s"Function with EvalType: ${pythonUdf.evalType} is not supported")
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 508cf56b9c873..a8544cf14a808 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -530,6 +530,7 @@ def __hash__(self):
         "pyspark.sql.tests.test_udf_profiler",
         "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_utils",
+        "pyspark.sql.tests.test_resources",
     ],
 )
 
diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py
index 60b36672ca506..25548a8b39577 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -15,9 +15,13 @@
 # limitations under the License.
 #
 import sys
-from typing import Union, TYPE_CHECKING
+from typing import Union, TYPE_CHECKING, Optional
 
+from py4j.java_gateway import JavaObject
+
+from pyspark.resource.requests import ExecutorResourceRequests, TaskResourceRequests
 from pyspark.rdd import PythonEvalType
+from pyspark.resource import ResourceProfile
 from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
@@ -32,7 +36,11 @@ class PandasMapOpsMixin:
     """
 
     def mapInPandas(
-        self, func: "PandasMapIterFunction", schema: Union[StructType, str], barrier: bool = False
+        self,
+        func: "PandasMapIterFunction",
+        schema: Union[StructType, str],
+        barrier: bool = False,
+        profile: Optional[ResourceProfile] = None,
     ) -> "DataFrame":
         """
         Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -65,6 +73,12 @@ def mapInPandas(
 
             .. versionadded: 3.5.0
 
+        profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
+            to be used for mapInPandas.
+
+            .. versionadded: 4.0.0
+
+
         Examples
         --------
         >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
@@ -141,11 +155,17 @@ def mapInPandas(
             func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
-        jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier)
+
+        jrp = self._build_java_profile(profile)
+        jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier, jrp)
         return DataFrame(jdf, self.sparkSession)
 
     def mapInArrow(
-        self, func: "ArrowMapIterFunction", schema: Union[StructType, str], barrier: bool = False
+        self,
+        func: "ArrowMapIterFunction",
+        schema: Union[StructType, str],
+        barrier: bool = False,
+        profile: Optional[ResourceProfile] = None,
     ) -> "DataFrame":
         """
         Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -175,6 +195,11 @@ def mapInArrow(
 
            .. versionadded: 3.5.0
 
+        profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
+            to be used for mapInArrow.
+
+            .. versionadded: 4.0.0
+
         Examples
         --------
         >>> import pyarrow  # doctest: +SKIP
@@ -220,9 +245,35 @@ def mapInArrow(
             func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
-        jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier)
+
+        jrp = self._build_java_profile(profile)
+        jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier, jrp)
         return DataFrame(jdf, self.sparkSession)
 
+    def _build_java_profile(
+        self, profile: Optional[ResourceProfile] = None
+    ) -> Optional[JavaObject]:
+        """Build the java ResourceProfile based on PySpark ResourceProfile"""
+        from pyspark.sql import DataFrame
+
+        assert isinstance(self, DataFrame)
+
+        jrp = None
+        if profile is not None:
+            if profile._java_resource_profile is not None:
+                jrp = profile._java_resource_profile
+            else:
+                jvm = self.sparkSession.sparkContext._jvm
+                assert jvm is not None
+
+                builder = jvm.org.apache.spark.resource.ResourceProfileBuilder()
+                ereqs = ExecutorResourceRequests(jvm, profile._executor_resource_requests)
+                treqs = TaskResourceRequests(jvm, profile._task_resource_requests)
+                builder.require(ereqs._java_executor_resource_requests)
+                builder.require(treqs._java_task_resource_requests)
+                jrp = builder.build()
+        return jrp
+
 
 def _test() -> None:
     import doctest
diff --git a/python/pyspark/sql/tests/test_resources.py b/python/pyspark/sql/tests/test_resources.py
new file mode 100644
index 0000000000000..9dfb14d9c37f7
--- /dev/null
+++ b/python/pyspark/sql/tests/test_resources.py
@@ -0,0 +1,104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark import SparkContext, TaskContext
+from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder
+from pyspark.sql import SparkSession
+from pyspark.testing.sqlutils import (
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message,
+)
+class ResourceProfileTestsMixin(object):
+    def test_map_in_arrow_without_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 1
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+        df.mapInArrow(func, "id long").collect()
+
+    def test_map_in_arrow_with_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 3
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+
+        treqs = TaskResourceRequests().cpus(3)
+        rp = ResourceProfileBuilder().require(treqs).build
+        df.mapInArrow(func, "id long", False, rp).collect()
+
+    def test_map_in_pandas_without_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 1
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+        df.mapInPandas(func, "id long").collect()
+
+    def test_map_in_pandas_with_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 3
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+
+        treqs = TaskResourceRequests().cpus(3)
+        rp = ResourceProfileBuilder().require(treqs).build
+        df.mapInPandas(func, "id long", False, rp).collect()
+
+
+class ResourceProfileTests(ResourceProfileTestsMixin, ReusedPySparkTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.sc = SparkContext("local-cluster[1, 4, 1024]", cls.__name__, conf=cls.conf())
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(ResourceProfileTests, cls).tearDownClass()
+        cls.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_resources import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 2017b2e8eef63..d696ff45b9b7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -382,13 +382,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ MapInPandas(_, output, _, _)
+      case oldVersion @ MapInPandas(_, output, _, _, _)
          if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ MapInArrow(_, output, _, _)
+      case oldVersion @ MapInArrow(_, output, _, _, _)
          if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
         newVersion.copyTagsFrom(oldVersion)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 2664809d4b266..01d5a1bdea6a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -77,7 +78,8 @@ case class MapInPandas(
     functionExpr: Expression,
     output: Seq[Attribute],
     child: LogicalPlan,
-    isBarrier: Boolean) extends UnaryNode {
+    isBarrier: Boolean,
+    profile: Option[ResourceProfile]) extends UnaryNode {
 
   override val producedAttributes = AttributeSet(output)
 
@@ -93,7 +95,8 @@ case class MapInArrow(
     functionExpr: Expression,
     output: Seq[Attribute],
     child: LogicalPlan,
-    isBarrier: Boolean) extends UnaryNode {
+    isBarrier: Boolean,
+    profile: Option[ResourceProfile]) extends UnaryNode {
 
   override val producedAttributes = AttributeSet(output)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 1dce989073260..ef1ecd33c0792 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -709,7 +709,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       pythonUdf,
       output,
       project,
-      false)
+      false,
+      None)
     val left = SubqueryAlias("temp0", mapInPandas)
     val right = SubqueryAlias("temp1", mapInPandas)
     val join = Join(left, right, Inner, None, JoinHint.NONE)
@@ -729,7 +730,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       pythonUdf,
       output,
       project,
-      false)
+      false,
+      None)
     assertAnalysisSuccess(mapInPandas)
   }
 
@@ -745,7 +747,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       pythonUdf,
       output,
       project,
-      false)
+      false,
+      None)
     assertAnalysisSuccess(mapInArrow)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 974b2dd7c7c10..189be1d6a30d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -37,6 +37,7 @@ import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
 import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -3515,14 +3516,18 @@ class Dataset[T] private[sql](
    * This function uses Apache Arrow as serialization format between Java executors and Python
    * workers.
    */
-  private[sql] def mapInPandas(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
+  private[sql] def mapInPandas(
+      func: PythonUDF,
+      isBarrier: Boolean = false,
+      profile: ResourceProfile = null): DataFrame = {
     Dataset.ofRows(
       sparkSession,
       MapInPandas(
         func,
         toAttributes(func.dataType.asInstanceOf[StructType]),
         logicalPlan,
-        isBarrier))
+        isBarrier,
+        Option(profile)))
   }
 
   /**
@@ -3530,14 +3535,18 @@ class Dataset[T] private[sql](
    * defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`.
    * Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches.
    */
-  private[sql] def mapInArrow(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
+  private[sql] def mapInArrow(
+      func: PythonUDF,
+      isBarrier: Boolean = false,
+      profile: ResourceProfile = null): DataFrame = {
     Dataset.ofRows(
       sparkSession,
       MapInArrow(
         func,
         toAttributes(func.dataType.asInstanceOf[StructType]),
         logicalPlan,
-        isBarrier))
+        isBarrier,
+        Option(profile)))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5d4063d125c8f..476088153ab69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -866,10 +866,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.python.FlatMapCoGroupsInArrowExec(
           f.leftAttributes, f.rightAttributes, func, output, planLater(left),
           planLater(right)) :: Nil
-      case logical.MapInPandas(func, output, child, isBarrier) =>
-        execution.python.MapInPandasExec(func, output, planLater(child), isBarrier) :: Nil
-      case logical.MapInArrow(func, output, child, isBarrier) =>
-        execution.python.MapInArrowExec(func, output, planLater(child), isBarrier) :: Nil
+      case logical.MapInPandas(func, output, child, isBarrier, profile) =>
+        execution.python.MapInPandasExec(func, output, planLater(child), isBarrier, profile) :: Nil
+      case logical.MapInArrow(func, output, child, isBarrier, profile) =>
+        execution.python.MapInArrowExec(func, output, planLater(child), isBarrier, profile) :: Nil
       case logical.AttachDistributedSequence(attr, child) =>
         execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInArrowExec.scala
index 2b1d1928ffd2c..4c1fecd02272c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInArrowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInArrowExec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 
@@ -29,7 +30,8 @@ case class MapInArrowExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan,
-    override val isBarrier: Boolean)
+    override val isBarrier: Boolean,
+    override val profile: Option[ResourceProfile])
   extends MapInBatchExec {
 
   override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 346a3a2ca354e..096e9d7d16420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.JobArtifactSet
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.rdd.RDD
+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -38,6 +39,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
 
   protected val isBarrier: Boolean
 
+  protected val profile: Option[ResourceProfile]
+
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
@@ -61,7 +64,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
       pythonMetrics,
       jobArtifactUUID)
 
-    if (isBarrier) {
+    val rdd = if (isBarrier) {
       val rddBarrier = child.execute().barrier()
       if (conf.usePartitionEvaluator) {
         rddBarrier.mapPartitionsWithEvaluator(evaluatorFactory)
@@ -80,5 +83,6 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
         }
       }
     }
+    profile.map(rp => rdd.withResources(rp)).getOrElse(rdd)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
index cfd97b6f49794..5dd461ff4c484 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 
@@ -29,7 +30,8 @@ case class MapInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan,
-    override val isBarrier: Boolean)
+    override val isBarrier: Boolean,
+    override val profile: Option[ResourceProfile])
   extends MapInBatchExec {
 
   override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
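Usage sketch (illustrative only, not part of the patch): with the new profile argument, a caller can request specific task resources for the stage that runs the Python function, assuming a cluster manager that supports stage-level scheduling and enough task slots per executor (the tests above use local-cluster[1, 4, 1024]). Note that build on ResourceProfileBuilder is a property, not a method.

    from pyspark import TaskContext
    from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local-cluster[1, 4, 1024]").getOrCreate()

    def func(iterator):
        # Tasks for this stage should be scheduled with 3 CPUs each.
        assert TaskContext.get().cpus() == 3
        for batch in iterator:
            yield batch

    # Build a profile with only task requests; executor requests are optional here.
    rp = ResourceProfileBuilder().require(TaskResourceRequests().cpus(3)).build
    spark.range(10).mapInPandas(func, "id long", profile=rp).show()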