diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1ec0e717fac29..29148a7ee558b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,12 +39,14 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_PANDAS_GROUP_AGG_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b7470..bfeb034e8e242 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -70,6 +70,7 @@ class PythonEvalType(object): SQL_PANDAS_SCALAR_UDF = 200 SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_PANDAS_GROUP_AGG_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 961b3267b44cf..a291c9b71913f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2089,6 +2089,8 @@ class PandasUDFType(object): GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + @since(1.3) def udf(f=None, returnType=StringType()): @@ -2159,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., `DoubleType()`. + The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and @@ -2221,6 +2223,35 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + 3. GROUP_AGG + + A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. + The returned scalar can be either a python primitive type, e.g., `int` or `float` + or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. + + :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as + output types. + + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. If your function is not deterministic, call @@ -2267,7 +2298,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): raise ValueError("Invalid returnType: returnType can not be None") if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 22061b83eb78c..f90a909d7c2b1 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -65,13 +65,27 @@ def __init__(self, jgd, df): def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + The available aggregate functions can be: + + 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` + + 2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf` + + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., + a full shuffle is required. Also, all the data of a group will be loaded into + memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. seealso:: :func:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed + in a single call to this function. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. @@ -82,6 +96,13 @@ def agg(self, *exprs): >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def min_udf(v): + ... return v.min() + >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP + [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -204,16 +225,18 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame`s are combined as a + to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - This function does not support partial aggregation, and requires shuffling all the data in - the :class:`DataFrame`. + .. note:: This function requires a full shuffle. all the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. :param udf: a group map user-defined function returned by - :meth:`pyspark.sql.functions.pandas_udf`. + :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f84aa3d68b808..095828bdfb051 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -197,6 +197,12 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -3354,12 +3360,6 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def create_pandas_data_frame(self): import pandas as pd import numpy as np @@ -3397,7 +3397,7 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3408,11 +3408,11 @@ def test_toPandas_respect_session_timezone(self): self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") try: pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_la, pdf_la) + self.assertPandasEqual(pdf_arrow_la, pdf_la) finally: self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) self.assertFalse(pdf_ny.equals(pdf_la)) @@ -3422,7 +3422,7 @@ def test_toPandas_respect_session_timezone(self): if isinstance(field.dataType, TimestampType): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) - self.assertFramesEqual(pdf_ny, pdf_la_corrected) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) @@ -3430,7 +3430,7 @@ def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_filtered_frame(self): df = self.spark.range(3).toDF("i") @@ -3488,7 +3488,7 @@ def test_createDataFrame_with_schema(self): df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEquals(self.schema, df.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() @@ -3700,7 +3700,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedSQLTestCase): +class ScalarPandasUDF(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4179,13 +4179,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedSQLTestCase): - - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) +class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4210,7 +4204,7 @@ def test_simple(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_register_group_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4234,7 +4228,7 @@ def foo(pdf): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_coerce(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4249,7 +4243,7 @@ def test_coerce(self): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_complex_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4268,7 +4262,7 @@ def normalize(pdf): expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_empty_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4287,7 +4281,7 @@ def normalize(pdf): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_datatype_string(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4301,7 +4295,7 @@ def test_datatype_string(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4353,6 +4347,446 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyAggPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + + @udf('double') + def plus_one(v): + assert isinstance(v, (int, float)) + return v + 1 + return plus_one + + @property + def pandas_scalar_plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum(v): + return v.sum() + return sum + + @property + def pandas_agg_weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean(v, w): + return np.average(v, weights=w) + return weighted_mean + + def test_manual(self): + df = self.data + sum_udf = self.pandas_agg_sum_udf + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + expected1 = self.spark.createDataFrame( + [[0, 245.0, 24.5], + [1, 255.0, 25.5], + [2, 265.0, 26.5], + [3, 275.0, 27.5], + [4, 285.0, 28.5], + [5, 295.0, 29.5], + [6, 305.0, 30.5], + [7, 315.0, 31.5], + [8, 325.0, 32.5], + [9, 335.0, 33.5]], + ['id', 'sum(v)', 'avg(v)']) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_basic(self): + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + # Groupby one column and aggregate one UDF with literal + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + # Groupby one expression and aggregate one UDF with literal + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + # Groupby one column and aggregate one UDF without literal + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + # Groupby one expression and aggregate one UDF without literal + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_unsupported_types(self): + from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return [v.mean(), v.std()] + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return v.mean(), v.std() + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return {v.mean(): v.std()} + + def test_alias(self): + from pyspark.sql.functions import mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + """ + Test mixing group aggregate pandas UDF with sql expression. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF with sql expression + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) + expected1 = (df.groupby('id') + .agg(sum(df.v) + 1) + .sort('id')) + + # Mix group aggregate pandas UDF with sql expression (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1)) + .sort('id')) + + # Wrap group aggregate pandas UDF with two sql expressions + result3 = (df.groupby('id') + .agg(sum_udf(df.v + 1) + 2) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(df.v + 1) + 2) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + def test_mixed_udfs(self): + """ + Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF and python UDF + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and python UDF (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + expected2 = (df.groupby('id') + .agg(sum(plus_one(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(plus_two(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v))) + .sort('id')) + + # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby + result5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum_udf(plus_one(df.v)))) + .sort('plus_one(id)')) + expected5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum(plus_one(df.v)))) + .sort('plus_one(id)')) + + # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in + # groupby + result6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum_udf(plus_two(df.v)))) + .sort('plus_two(id)')) + expected6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum(plus_two(df.v)))) + .sort('plus_two(id)')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + + def test_multiple_udfs(self): + """ + Test multiple group aggregate pandas UDFs in one agg function. + """ + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + sum_udf = self.pandas_agg_sum_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + expected1 = (df.groupBy('id') + .agg(mean(df.v), + sum(df.v), + mean(df.v).alias('weighted_mean(v, w)')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + + def test_complex_groupby(self): + from pyspark.sql.functions import lit, sum + + df = self.data + sum_udf = self.pandas_agg_sum_udf + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + + # groupby one expression + result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) + expected1 = df.groupby(df.v % 2).agg(sum(df.v)) + + # empty groupby + result2 = df.groupby().agg(sum_udf(df.v)) + expected2 = df.groupby().agg(sum(df.v)) + + # groupby one column and one sql expression + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + + # groupby one python UDF + result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) + + # groupby one scalar pandas UDF + result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) + + # groupby one expression and one python UDF + result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) + + # groupby one expression and one scalar pandas UDF + result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') + expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) + + def test_complex_expressions(self): + from pyspark.sql.functions import col, sum + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Test complex expressions with sql expression, python UDF and + # group aggregate pandas UDF + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_one(sum(col('v1'))), + sum(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + # Test complex expressions with sql expression, scala pandas UDF and + # group aggregate pandas UDF + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_two(sum(col('v1'))), + sum(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + # Test sequential groupby aggregate + result3 = (df.groupby('id') + .agg(sum_udf(df.v).alias('v')) + .groupby('id') + .agg(sum_udf(col('v'))) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(sum(df.v).alias('v')) + .groupby('id') + .agg(sum(col('v'))) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_retain_group_columns(self): + from pyspark.sql.functions import sum, lit, col + orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) + self.spark.conf.set("spark.sql.retainGroupColumns", False) + try: + df = self.data + sum_udf = self.pandas_agg_sum_udf + + result1 = df.groupby(df.id).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id).agg(sum(df.v)) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.retainGroupColumns") + else: + self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) + + def test_invalid_args(self): + from pyspark.sql.functions import mean + + df = self.data + plus_one = self.python_plus_one + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'nor.*aggregate function'): + df.groupby(df.id).agg(plus_one(df.v)).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'aggregate function.*argument.*aggregate function'): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'mixture.*aggregate function.*group aggregate pandas UDF'): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c77f19f89a442..094b8c663737e 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -22,7 +22,8 @@ from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ + _parse_datatype_string __all__ = ["UDFRegistration"] @@ -36,8 +37,10 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ - evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + import inspect from pyspark.sql.utils import require_minimum_pyarrow_version @@ -113,6 +116,10 @@ def returnType(self): and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " "pandas_udf with function type GROUP_MAP") + elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): + raise NotImplementedError( + "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e6737ae1c1285..173d8fb2856fa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -110,6 +110,17 @@ def wrapped(*series): return wrapped +def wrap_pandas_group_agg_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series(result) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -126,8 +137,12 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - else: + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: + return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) + else: + raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): @@ -148,8 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bbcec5627bd49..ef91d79f3302c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -153,11 +153,19 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${condition.dataType.simpleString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => + def isAggregateExpression(expr: Expression) = { + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + } + def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case aggExpr: AggregateExpression => - aggExpr.aggregateFunction.children.foreach { child => + case expr: Expression if isAggregateExpression(expr) => + val aggFunction = expr match { + case agg: AggregateExpression => agg.aggregateFunction + case udf: PythonUDF => udf + } + aggFunction.children.foreach { child => child.foreach { - case agg: AggregateExpression => + case expr: Expression if isAggregateExpression(expr) => failAnalysis( s"It is not allowed to use an aggregate function in the argument of " + s"another aggregate function. Please use the inner aggregate function " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index d3f743d9eb61e..4ba8ff6e3802f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -15,12 +15,31 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.python +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType +/** + * Helper functions for [[PythonUDF]] + */ +object PythonUDF { + private[this] val SCALAR_TYPES = Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) + } + + def isGroupAggPandasUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + } +} + /** * A serialized version of a Python lambda function. */ @@ -30,12 +49,16 @@ case class PythonUDF( dataType: DataType, children: Seq[Expression], evalType: Int, - udfDeterministic: Boolean) + udfDeterministic: Boolean, + resultId: ExprId = NamedExpression.newExprId) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"$name(${children.mkString(", ")})" + lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( + exprId = resultId) + override def nullable: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cc391aae55787..132241061d510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.planning +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { object PhysicalAggregation { // groupingExpressions, aggregateExpressions, resultExpressions, child type ReturnType = - (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { case logical.Aggregate(groupingExpressions, resultExpressions, child) => @@ -213,7 +214,10 @@ object PhysicalAggregation { expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. case agg: AggregateExpression - if (!equivalentAggregateExpressions.addExpr(agg)) => agg + if !equivalentAggregateExpressions.addExpr(agg) => agg + case udf: PythonUDF + if PythonUDF.isGroupAggPandasUDF(udf) && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -241,6 +245,10 @@ object PhysicalAggregation { // so replace each aggregate expression by its corresponding attribute in the set: equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute + // Similar to AggregateExpression + case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + equivalentAggregateExpressions.getEquivalentExprs(ue).headOption + .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a009c00b0abc5..d320c1c359411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 910294853c318..ce512bc46563a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -288,9 +289,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + throw new AnalysisException( + "Streaming aggregation doesn't support group aggregate pandas UDF") + } + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, - aggregateExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, planLater(child)) @@ -333,8 +339,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child) => + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -363,6 +371,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateOperator + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + + case PhysicalAggregation(_, _, _, _) => + // If cannot match the two cases above, then it's an error + throw new AnalysisException( + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala new file mode 100644 index 0000000000000..18e5f8605c60d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +/** + * Physical node for aggregation with group aggregate Pandas UDF. + * + * This plan works by sending the necessary (projected) input grouped data as Arrow record batches + * to the python worker, the python worker invokes the UDF and sends the results to the executor, + * finally the executor evaluates any post-aggregation expressions and join the result with the + * grouped key. + */ +case class AggregateInPandasExec( + groupingExpressions: Seq[NamedExpression], + udfExpressions: Seq[PythonUDF], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExpressions.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val prunedProj = UnsafeProjection.create(allInputs, child.output) + + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } + + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } + + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2f53fe788c7d0..1862e3f6e12ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,12 +39,13 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined }.isDefined } @@ -93,7 +94,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { - e.find(_.isInstanceOf[PythonUDF]).isDefined + e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { @@ -106,12 +107,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) case e => e.children.flatMap(collectEvaluatableUDF) } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // FlatMapGroupsInPandas can be evaluated directly in python worker + // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) @@ -149,10 +150,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - require(validUdfs.forall(udf => - udf.evalType == PythonEvalType.SQL_BATCHED_UDF || - udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF - ), "Can only extract scalar vectorized udf or sql batch udf") + require( + validUdfs.forall(PythonUDF.isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 50dca32cb7861..f4c2d02ee9420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} import org.apache.spark.sql.types.DataType /**