
Commit 470cacd

e-dorigatti authored and HyukjinKwon committed
[SPARK-23754][PYTHON][FOLLOWUP][BACKPORT-2.3] Move UDF stop iteration wrapping from driver to executor
SPARK-23754 was fixed in #21383 by changing the UDF code to wrap the user function, but this required a hack to save its argspec. This PR reverts that change and fixes the `StopIteration` bug in the worker.

The root of the problem is that when a user-supplied function raises a `StopIteration`, PySpark may silently stop processing data if that function is used in a for loop. The solution is to catch `StopIteration` exceptions and re-raise them as `RuntimeError`s, so that the execution fails and the error is reported to the user. This is done with the `fail_on_stopiteration` wrapper, applied in different places depending on where the function is used:

- In RDDs, the user function is wrapped in the driver, because the function is also called in the driver itself.
- In SQL UDFs, the function is wrapped in the worker, since all processing happens there. Moreover, the worker needs the signature of the user function, which is lost when wrapping it, and passing that signature to the worker would require an awkward hack.

HyukjinKwon

Author: edorigatti <emilio.dorigatti@gmail.com>
Author: e-dorigatti <emilio.dorigatti@gmail.com>

Closes #21538 from e-dorigatti/branch-2.3.
1 parent a55de38 commit 470cacd
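
To make the failure mode concrete, here is a minimal standalone sketch (not part of the commit; `buggy_udf` and `data` are hypothetical names) of how a stray `StopIteration` silently truncates results when a user function is driven through an iterator, which mirrors how PySpark feeds user functions into its processing loops:

    # Sketch only: a StopIteration escaping from the mapped function is
    # indistinguishable from normal iterator exhaustion to the consumer.
    def buggy_udf(x):
        if x == 3:
            raise StopIteration()   # accidental StopIteration in user code
        return x * 2

    data = [1, 2, 3, 4, 5]

    # list() treats the escaping StopIteration as "no more items", so
    # rows 3, 4 and 5 are dropped without any error being reported.
    print(list(map(buggy_udf, data)))   # prints [2, 4]

The commit addresses this by wrapping the user function so that a `StopIteration` is re-raised as a `RuntimeError`, making the task fail loudly instead of losing data silently (see python/pyspark/util.py and python/pyspark/worker.py below).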

File tree

5 files changed: +70 −38 lines


python/pyspark/sql/tests.py

Lines changed: 38 additions & 16 deletions
@@ -853,22 +853,6 @@ def __call__(self, x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
-    def test_stopiteration_in_udf(self):
-        # test for SPARK-23754
-        from pyspark.sql.functions import udf
-        from py4j.protocol import Py4JJavaError
-
-        def foo(x):
-            raise StopIteration()
-
-        with self.assertRaises(Py4JJavaError) as cm:
-            self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
-
-        self.assertIn(
-            "Caught StopIteration thrown from user's code; failing the task",
-            cm.exception.java_exception.toString()
-        )
-
     def test_validate_column_types(self):
         from pyspark.sql.functions import udf, to_json
         from pyspark.sql.column import _to_java_column
@@ -3917,6 +3901,44 @@ def foo(df):
         def foo(k, v):
             return k
 
+    def test_stopiteration_in_udf(self):
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+        from py4j.protocol import Py4JJavaError
+
+        def foo(x):
+            raise StopIteration()
+
+        def foofoo(x, y):
+            raise StopIteration()
+
+        exc_message = "Caught StopIteration thrown from user's code; failing the task"
+        df = self.spark.range(0, 100)
+
+        # plain udf (test for SPARK-23754)
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn('v', udf(foo)('id')).collect
+        )
+
+        # pandas scalar udf
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn(
+                'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
+            ).collect
+        )
+
+        # pandas grouped map
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').apply(
+                pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
+            ).collect
+        )
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

python/pyspark/sql/udf.py

Lines changed: 1 addition & 3 deletions
@@ -24,7 +24,6 @@
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
     to_arrow_type, to_arrow_schema
-from pyspark.util import fail_on_stopiteration
 
 __all__ = ["UDFRegistration"]
 
@@ -155,8 +154,7 @@ def _create_judf(self):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext
 
-        func = fail_on_stopiteration(self.func)
-        wrapped_func = _wrap_function(sc, func, self.returnType)
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)

python/pyspark/tests.py

Lines changed: 22 additions & 15 deletions
@@ -1270,27 +1270,34 @@ def test_pipe_functions(self):
         self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
         self.assertEqual([], rdd.pipe('grep 4').collect())
 
-    def test_stopiteration_in_client_code(self):
+    def test_stopiteration_in_user_code(self):
 
         def stopit(*x):
             raise StopIteration()
 
         seq_rdd = self.sc.parallelize(range(10))
         keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
-
-        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
-        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
-
-        # the exception raised is non-deterministic
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+        msg = "Caught StopIteration thrown from user's code; failing the task"
+
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg,
+                                seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+
+        # these methods call the user function both in the driver and in the executor
+        # the exception raised is different according to where the StopIteration happens
+        # RuntimeError is raised if in the driver
+        # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, lambda *x: 1, stopit)
 
 
 class ProfilerTests(PySparkTestCase):

python/pyspark/util.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def _exception_message(excp):
 def fail_on_stopiteration(f):
     """
     Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
-    prevents silent loss of data when 'f' is used in a for loop
+    prevents silent loss of data when 'f' is used in a for loop in Spark code
     """
     def wrapper(*args, **kwargs):
         try:
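
The hunk above stops at the `try:` line because only the docstring changes in this file. For context, a sketch of what the complete helper looks like after this change; the exact wording of the re-raised error is an assumption based on the message asserted in the tests above:

    def fail_on_stopiteration(f):
        """
        Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
        prevents silent loss of data when 'f' is used in a for loop in Spark code
        """
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except StopIteration as exc:
                # re-raise as RuntimeError so callers iterating over results see a
                # hard failure instead of an early, silent end of iteration
                raise RuntimeError(
                    "Caught StopIteration thrown from user's code; failing the task", exc)
        return wrapper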

python/pyspark/worker.py

Lines changed: 8 additions & 3 deletions
@@ -35,6 +35,7 @@
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
+from pyspark.util import fail_on_stopiteration
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -122,13 +123,17 @@ def read_single_udf(pickleSer, infile, eval_type):
     else:
         row_func = chain(row_func, f)
 
+    # make sure StopIteration's raised in the user code are not ignored
+    # when they are processed in a for loop, raise them as RuntimeError's instead
+    func = fail_on_stopiteration(row_func)
+
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type)
     else:
-        return arg_offsets, wrap_udf(row_func, return_type)
+        return arg_offsets, wrap_udf(func, return_type)
 
 
 def read_udfs(pickleSer, infile, eval_type):
