diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5833734103a..ae34766ff0c1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2718,9 +2718,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + :class:`MapType`, :class:`StructType` are currently not supported as output types. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and :meth:`pyspark.sql.DataFrame.select`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 91ed600afedd..a739181e1d82 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4358,6 +4358,7 @@ def test_timestamp_dst(self): not _have_pandas or not _have_pyarrow, _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4573,6 +4574,24 @@ def random_udf(v): random_udf = random_udf.asNondeterministic() return random_udf + def test_pandas_udf_tokenize(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), + ArrayType(StringType())) + self.assertEqual(tokenize.returnType, ArrayType(StringType())) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) + + def test_pandas_udf_nested_arrays(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), + ArrayType(ArrayType(StringType()))) + self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select(