Skip to content

Commit b8ffa50

Browse files
committed
added test and fix for chained pandas_udf
1 parent 53926cc commit b8ffa50

File tree

2 files changed

+25 -4 lines changed

2 files changed

+25 -4 lines changed

python/pyspark/sql/tests.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3300,6 +3300,24 @@ def test_vectorized_udf_mix_udf(self):
33003300
'Can not mix vectorized and non-vectorized UDFs'):
33013301
df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
33023302

3303+
def test_vectorized_udf_chained(self):
3304+
from pyspark.sql.functions import pandas_udf, col
3305+
df = self.spark.range(10).toDF('x')
3306+
f = pandas_udf(lambda x: x + 1, LongType())
3307+
g = pandas_udf(lambda x: x - 1, LongType())
3308+
res = df.select(g(f(col('x'))))
3309+
self.assertEquals(df.collect(), res.collect())
3310+
3311+
def test_vectorized_udf_wrong_return_type(self):
3312+
from pyspark.sql.functions import pandas_udf, col
3313+
df = self.spark.range(10).toDF('x')
3314+
f = pandas_udf(lambda x: x * 1.0, StringType())
3315+
with QuietTest(self.sc):
3316+
with self.assertRaisesRegexp(
3317+
Exception,
3318+
'Invalid.*type.*string'):
3319+
df.select(f(col('x'))).collect()
3320+
33033321

33043322
if __name__ == "__main__":
33053323
from pyspark.sql.tests import *

python/pyspark/worker.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -60,9 +60,12 @@ def read_command(serializer, file):
6060
return command
6161

6262

63-
def chain(f, g):
64-
"""chain two function together """
65-
return lambda *a: g(f(*a))
63+
def chain(f, g, eval_type):
64+
"""chain two functions together """
65+
if eval_type == PythonEvalType.SQL_PANDAS_UDF:
66+
return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs)
67+
else:
68+
return lambda *a: g(f(*a))
6669

6770

6871
def wrap_udf(f, return_type):
@@ -96,7 +99,7 @@ def read_single_udf(pickleSer, infile, eval_type):
9699
if row_func is None:
97100
row_func = f
98101
else:
99-
row_func = chain(row_func, f)
102+
row_func = chain(row_func, f, eval_type)
100103
# the last returnType will be the return type of UDF
101104
if eval_type == PythonEvalType.SQL_PANDAS_UDF:
102105
# A pandas_udf will take kwargs as the last argument

0 commit comments

Comments
 (0)