@@ -36,6 +36,7 @@ private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 1
val SQL_PANDAS_UDF = 2
val SQL_PANDAS_GROUPED_UDF = 3
}

/**
1 change: 1 addition & 0 deletions python/pyspark/serializers.py
@@ -86,6 +86,7 @@ class PythonEvalType(object):
NON_UDF = 0
SQL_BATCHED_UDF = 1
SQL_PANDAS_UDF = 2
SQL_PANDAS_GROUPED_UDF = 3


class Serializer(object):
33 changes: 21 additions & 12 deletions python/pyspark/sql/functions.py
@@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType):
sc.pythonVer, broadcast_vars, sc._javaAccumulator)


class PythonUdfType(object):
# row-at-a-time UDFs
NORMAL_UDF = 0
# scalar vectorized UDFs
PANDAS_UDF = 1
# grouped vectorized UDFs
PANDAS_GROUPED_UDF = 2


class UserDefinedFunction(object):
"""
User defined function in Python

.. versionadded:: 1.3
"""
def __init__(self, func, returnType, name=None, vectorized=False):
def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF):
if not callable(func):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
@@ -2058,7 +2067,7 @@ def __init__(self, func, returnType, name=None, vectorized=False):
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.vectorized = vectorized
self.pythonUdfType = pythonUdfType

@property
def returnType(self):
@@ -2090,7 +2099,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.vectorized)
self._name, wrapped_func, jdt, self.pythonUdfType)
return judf

def __call__(self, *cols):
@@ -2121,33 +2130,33 @@ def wrapper(*args):

wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.vectorized = self.vectorized
wrapper.pythonUdfType = self.pythonUdfType

return wrapper


def _create_udf(f, returnType, vectorized):
def _create_udf(f, returnType, pythonUdfType):

def _udf(f, returnType=StringType(), vectorized=vectorized):
if vectorized:
def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType):
if pythonUdfType == PythonUdfType.PANDAS_UDF:
import inspect
argspec = inspect.getargspec(f)
if len(argspec.args) == 0 and argspec.varargs is None:
raise ValueError(
"0-arg pandas_udfs are not supported. "
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType)
return udf_obj._wrapped()

# decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType)
else:
return _udf(f=f, returnType=returnType, vectorized=vectorized)
return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType)


@since(1.3)
@@ -2181,7 +2190,7 @@ def udf(f=None, returnType=StringType()):
| 8| JOHN DOE| 22|
+----------+--------------+------------+
"""
return _create_udf(f, returnType=returnType, vectorized=False)
return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF)


@since(2.3)
@@ -2252,7 +2261,7 @@ def pandas_udf(f=None, returnType=StringType()):

.. note:: The user-defined function must be deterministic.
"""
return _create_udf(f, returnType=returnType, vectorized=True)
return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF)


blacklist = ['map', 'since', 'ignore_unicode_prefix']
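A minimal usage sketch of the new UDF typing (illustrative names; not part of this change): after this patch, udf produces a NORMAL_UDF and pandas_udf a PANDAS_UDF, and the type is recorded on the returned wrapper as pythonUdfType.

from pyspark.sql.functions import udf, pandas_udf, PythonUdfType
from pyspark.sql.types import LongType

@udf(LongType())
def plus_one(v):
    # row-at-a-time: receives and returns one value per row
    return v + 1

@pandas_udf(LongType())
def pandas_plus_one(v):
    # scalar vectorized: receives and returns a pandas.Series per batch
    return v + 1

assert plus_one.pythonUdfType == PythonUdfType.NORMAL_UDF
assert pandas_plus_one.pythonUdfType == PythonUdfType.PANDAS_UDF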
14 changes: 9 additions & 5 deletions python/pyspark/sql/group.py
@@ -19,6 +19,7 @@
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import PythonUdfType, UserDefinedFunction
from pyspark.sql.types import *

__all__ = ["GroupedData"]
@@ -235,11 +236,13 @@ def apply(self, udf):
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

"""
from pyspark.sql.functions import pandas_udf
import inspect

# Columns are special because hasattr always returns True
if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized:
raise ValueError("The argument to apply must be a pandas_udf")
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \
or len(inspect.getargspec(udf.func).args) != 1:
raise ValueError("The argument to apply must be a 1-arg pandas_udf")
if not isinstance(udf.returnType, StructType):
raise ValueError("The returnType of the pandas_udf must be a StructType")

@@ -268,8 +271,9 @@ def wrapped(*cols):
return [(result[result.columns[i]], arrow_type)
for i, arrow_type in enumerate(arrow_return_types)]

wrapped_udf_obj = pandas_udf(wrapped, returnType)
udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
udf_obj = UserDefinedFunction(
wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF)
udf_column = udf_obj(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)
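For reference, a minimal end-to-end sketch of the groupby().apply() path this re-wrapping supports (assumes a running SparkSession named spark; data and column names are illustrative):

from pyspark.sql.functions import pandas_udf

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

# apply() requires a 1-arg pandas_udf whose returnType is a StructType;
# internally it is re-wrapped as a PANDAS_GROUPED_UDF before execution.
@pandas_udf("id long, v double, v_centered double")
def center(pdf):
    # pdf is a pandas.DataFrame holding one group
    return pdf.assign(v_centered=pdf.v - pdf.v.mean())

df.groupby("id").apply(center).show()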

37 changes: 37 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3383,6 +3383,15 @@ def test_vectorized_udf_varargs(self):
res = df.select(f(col('id')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType([StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
f = pandas_udf(lambda x: x, DateType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('dt'))).collect()


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedPySparkTestCase):
@@ -3492,6 +3501,18 @@ def normalize(pdf):
expected = expected.assign(norm=expected.norm.astype('float64'))
self.assertFramesEqual(expected, result)

def test_datatype_string(self):
from pyspark.sql.functions import pandas_udf
df = self.data

foo_udf = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
"id long, v int, v1 double, v2 long")

result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
self.assertFramesEqual(expected, result)

def test_wrong_return_type(self):
from pyspark.sql.functions import pandas_udf
df = self.data
@@ -3517,9 +3538,25 @@ def test_wrong_args(self):
df.groupby('id').apply(sum(df.v))
with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
df.groupby('id').apply(df.v + 1)
with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
df.groupby('id').apply(
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
df.groupby('id').apply(
pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())])))
with self.assertRaisesRegexp(ValueError, 'returnType'):
df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType()))

def test_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType(
[StructField("id", LongType(), True), StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema)
f = pandas_udf(lambda x: x, df.schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.groupby('id').apply(f).collect()


if __name__ == "__main__":
from pyspark.sql.tests import *
39 changes: 17 additions & 22 deletions python/pyspark/worker.py
@@ -32,7 +32,7 @@
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type, StructType
from pyspark.sql.types import to_arrow_type
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -74,28 +74,19 @@ def wrap_udf(f, return_type):


def wrap_pandas_udf(f, return_type):
# If the return_type is a StructType, it indicates this is a groupby apply udf,
# and has already been wrapped under apply(), otherwise, it's a vectorized column udf.
# We can distinguish these two by return type because in groupby apply, we always specify
# returnType as a StructType, and in vectorized column udf, StructType is not supported.
#
# TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs
if isinstance(return_type, StructType):
return lambda *a: f(*a)
else:
arrow_return_type = to_arrow_type(return_type)
arrow_return_type = to_arrow_type(return_type)

def verify_result_length(*a):
result = f(*a)
if not hasattr(result, "__len__"):
raise TypeError("Return type of the user-defined functon should be "
"Pandas.Series, but is {}".format(type(result)))
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
"expected %d, got %d" % (len(a[0]), len(result)))
return result
def verify_result_length(*a):
result = f(*a)
if not hasattr(result, "__len__"):
raise TypeError("Return type of the user-defined functon should be "
"Pandas.Series, but is {}".format(type(result)))
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
"expected %d, got %d" % (len(a[0]), len(result)))
return result

return lambda *a: (verify_result_length(*a), arrow_return_type)
return lambda *a: (verify_result_length(*a), arrow_return_type)


def read_single_udf(pickleSer, infile, eval_type):
@@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type):
# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_PANDAS_UDF:
return arg_offsets, wrap_pandas_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
# a groupby apply udf has already been wrapped under apply()
return arg_offsets, row_func
else:
return arg_offsets, wrap_udf(row_func, return_type)

@@ -133,7 +127,8 @@ def read_udfs(pickleSer, infile, eval_type):

func = lambda _, it: map(mapper, it)

if eval_type == PythonEvalType.SQL_PANDAS_UDF:
if eval_type == PythonEvalType.SQL_PANDAS_UDF \
or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
ser = ArrowStreamPandasSerializer()
else:
ser = BatchedSerializer(PickleSerializer(), 100)
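A user-level sketch (assumed column and function names) of the check wrap_pandas_udf still enforces for scalar pandas UDFs; grouped UDFs bypass it because they are pre-wrapped in apply() and return a pandas.DataFrame per group rather than a pandas.Series per batch:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType

@pandas_udf(LongType())
def truncate(v):
    # returns a shorter pandas.Series than it received
    return v.head(1)

# Evaluating e.g. df.select(truncate(df.id)).collect() on a multi-row df would
# raise: RuntimeError: Result vector from pandas_udf was not the required length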
@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre
* This is used by DataFrame.groupby().apply().
*/
case class FlatMapGroupsInPandas(
groupingAttributes: Seq[Attribute],
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
groupingAttributes: Seq[Attribute],
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {

/**
* This is needed because output attributes are considered `references` when
* passed through the constructor.
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.python.PythonUDF
import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{NumericType, StructField, StructType}

@@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql](
}

/**
* Applies a vectorized python user-defined function to each group of data.
* Applies a grouped vectorized python user-defined function to each group of data.
* The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`.
* For each group, all elements in the group are passed as a `pandas.DataFrame` and the results
* for all groups are combined into a new [[DataFrame]].
@@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql](
* workers.
*/
private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = {
require(expr.vectorized, "Must pass a vectorized python udf")
require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF,
"Must pass a grouped vectorized python udf")
require(expr.dataType.isInstanceOf[StructType],
"The returnType of the vectorized python udf must be a StructType")

@@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
udf.references.subsetOf(child.outputSet)
}
if (validUdfs.nonEmpty) {
if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) {
throw new IllegalArgumentException("Can not use grouped vectorized UDFs")
}

val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
AttributeReference(s"pythonUDF$i", u.dataType)()
}

val evaluation = validUdfs.partition(_.vectorized) match {
val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
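A hedged sketch of the case this new guard rejects (hypothetical construction; a PANDAS_GROUPED_UDF is normally created only inside groupby().apply(), and column names are illustrative):

from pyspark.sql.functions import PythonUdfType, UserDefinedFunction
from pyspark.sql.types import StructType, StructField, LongType

grouped = UserDefinedFunction(
    lambda pdf: pdf,
    StructType([StructField("id", LongType())]),
    pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF)

# df.select(grouped(df.id)).collect() would be rejected during planning with
# IllegalArgumentException: Can not use grouped vectorized UDFs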
@@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec(

val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema)
.compute(grouped, context.partitionId(), context)

columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
@@ -29,7 +29,7 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression],
vectorized: Boolean)
pythonUdfType: Int)
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {

override def toString: String = s"$name(${children.mkString(", ")})"
@@ -22,17 +22,26 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.DataType

private[spark] object PythonUdfType {
// row-at-a-time UDFs
val NORMAL_UDF = 0
// scalar vectorized UDFs
val PANDAS_UDF = 1
// grouped vectorized UDFs
val PANDAS_GROUPED_UDF = 2
}

/**
* A user-defined Python function. This is used by the Python API.
*/
case class UserDefinedPythonFunction(
name: String,
func: PythonFunction,
dataType: DataType,
vectorized: Boolean) {
pythonUdfType: Int) {

def builder(e: Seq[Expression]): PythonUDF = {
PythonUDF(name, func, dataType, e, vectorized)
PythonUDF(name, func, dataType, e, pythonUdfType)
}

/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
@@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
name = "dummyUDF",
func = new DummyUDF,
dataType = BooleanType,
vectorized = false)
pythonUdfType = PythonUdfType.NORMAL_UDF)