Skip to content

Commit 91885e5

Browse files
committed
Address comments
1 parent 4d22107 commit 91885e5

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

python/pyspark/sql/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2226,7 +2226,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
22262226
3. GROUP_AGG
22272227
22282228
A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
2229-
The `returnType` should be a primitive data type, e.g, :class:`DoubleType`.
2229+
The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
22302230
The returned scalar can be either a python primitive type, e.g., `int` or `float`
22312231
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
22322232

python/pyspark/sql/udf.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ def _wrap_function(sc, func, returnType):
3737

3838
def _create_udf(f, returnType, evalType):
3939

40-
if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \
41-
evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
40+
if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF,
41+
PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
42+
PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF):
43+
4244
import inspect
4345
from pyspark.sql.utils import require_minimum_pyarrow_version
4446

0 commit comments

Comments
 (0)