diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 57447d5689201..4888d223b42db 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -658,6 +658,11 @@ "Expected values for ``, got ." ] }, + "UDF_RETURN_TYPE" : { + "message" : [ + "Return type of the user-defined function should be , but is ." + ] + }, "UDTF_EXEC_ERROR" : { "message" : [ "User defined table function encountered an error in the '' method: " diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 6f2c8389a4c12..d8a3f812c33ab 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -399,7 +399,7 @@ class DataFrame(Frame, Generic[T]): `compute.ops_on_diff_frames` should be turned on; 2, when `data` is a local dataset (Pandas DataFrame/numpy ndarray/list/etc), it will first collect the `index` to driver if necessary, and then apply - the `Pandas.DataFrame(...)` creation internally; + the `pandas.DataFrame(...)` creation internally; Examples -------- diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f835ea57b7751..e82a59573a9c0 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -155,7 +155,7 @@ def wrap_and_init_stream(): class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ - Serializes Pandas.Series as Arrow data with Arrow streaming format. + Serializes pandas.Series as Arrow data with Arrow streaming format. Parameters ---------- diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_map.py b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py index ed51d0d3d1996..868aeaeff7fe6 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py @@ -22,7 +22,8 @@ class ArrowMapParityTests(MapInArrowTestsMixin, ReusedConnectTestCase): - pass + def test_other_than_recordbatch_iter(self): + self.check_other_than_recordbatch_iter() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py index 539fd98266b28..6ff9b0cb33b28 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py @@ -14,16 +14,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import unittest + + from pyspark.sql.tests.pandas.test_pandas_map import MapInPandasTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase class MapInPandasParityTests(MapInPandasTestsMixin, ReusedConnectTestCase): + def test_other_than_dataframe_iter(self): + self.check_other_than_dataframe_iter() + + def test_dataframes_with_other_column_names(self): + self.check_dataframes_with_other_column_names() + + def test_dataframes_with_duplicate_column_names(self): + self.check_dataframes_with_duplicate_column_names() + + def test_dataframes_with_less_columns(self): + self.check_dataframes_with_less_columns() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_dataframes_with_incompatible_types(self): + self.check_dataframes_with_incompatible_types() + def test_empty_dataframes_with_less_columns(self): self.check_empty_dataframes_with_less_columns() - def test_other_than_dataframe(self): - self.check_other_than_dataframe() + def test_empty_dataframes_with_other_columns(self): + self.check_empty_dataframes_with_other_columns() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py index 8def08323bec3..b867156e71a5d 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py @@ -56,7 +56,6 @@ class CogroupedApplyInPandasTestsMixin: def data1(self): return ( self.spark.range(10) - .toDF("id") .withColumn("ks", array([lit(i) for i in range(20, 30)])) .withColumn("k", explode(col("ks"))) .withColumn("v", col("k") * 10) @@ -67,7 +66,6 @@ def data1(self): def data2(self): return ( self.spark.range(10) - .toDF("id") .withColumn("ks", array([lit(i) for i in range(20, 30)])) .withColumn("k", explode(col("ks"))) .withColumn("v2", col("k") * 100) @@ -168,7 +166,7 @@ def check_apply_in_pandas_not_returning_pandas_dataframe(self): fn=lambda lft, rgt: lft.size + rgt.size, error_class=PythonException, error_message_regex="Return type of the user-defined function " - "should be pandas.DataFrame, but is ", + "should be pandas.DataFrame, but is int64.", ) def test_apply_in_pandas_returning_column_names(self): diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 84e61d42843ae..742b3657f6e75 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -79,7 +79,6 @@ class GroupedApplyInPandasTestsMixin: def data(self): return ( self.spark.range(10) - .toDF("id") .withColumn("vs", array([lit(i) for i in range(20, 30)])) .withColumn("v", explode(col("vs"))) .drop("vs") @@ -287,8 +286,7 @@ def test_apply_in_pandas_not_returning_pandas_dataframe(self): def check_apply_in_pandas_not_returning_pandas_dataframe(self): with self.assertRaisesRegex( PythonException, - "Return type of the user-defined function should be pandas.DataFrame, " - "but is ", + "Return type of the user-defined function should be pandas.DataFrame, but is tuple.", ): self._test_apply_in_pandas(lambda key, pdf: key) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py index 3d9a90bc81c40..fb2f9214c5d8f 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -42,15 +42,46 @@ cast(str, pandas_requirement_message or 
pyarrow_requirement_message), ) class MapInPandasTestsMixin: - def test_map_in_pandas(self): + @staticmethod + def identity_dataframes_iter(*columns: str): def func(iterator): for pdf in iterator: assert isinstance(pdf, pd.DataFrame) - assert pdf.columns == ["id"] + assert pdf.columns.tolist() == list(columns) yield pdf + return func + + @staticmethod + def identity_dataframes_wo_column_names_iter(*columns: str): + def func(iterator): + for pdf in iterator: + assert isinstance(pdf, pd.DataFrame) + assert pdf.columns.tolist() == list(columns) + yield pdf.rename(columns=list(pdf.columns).index) + + return func + + @staticmethod + def dataframes_and_empty_dataframe_iter(*columns: str): + def func(iterator): + for pdf in iterator: + yield pdf + # after yielding all elements, also yield an empty dataframe with given columns + yield pd.DataFrame([], columns=list(columns)) + + return func + + def test_map_in_pandas(self): + # test returning iterator of DataFrames + df = self.spark.range(10, numPartitions=3) + actual = df.mapInPandas(self.identity_dataframes_iter("id"), "id long").collect() + expected = df.collect() + self.assertEqual(actual, expected) + + # test returning list of DataFrames df = self.spark.range(10, numPartitions=3) - actual = df.mapInPandas(func, "id long").collect() + actual = df.mapInPandas(lambda it: [pdf for pdf in it], "id long").collect() expected = df.collect() self.assertEqual(actual, expected) @@ -85,6 +116,18 @@ def func(iterator): expected = df.collect() self.assertEqual(actual, expected) + def test_no_column_names(self): + data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")] + df = self.spark.createDataFrame(data, "a int, b string") + + def func(iterator): + for pdf in iterator: + yield pdf.rename(columns=list(pdf.columns).index) + + actual = df.mapInPandas(func, df.schema).collect() + expected = df.collect() + self.assertEqual(actual, expected) + def test_different_output_length(self): def func(iterator): for _ in iterator: @@ -94,20 +137,161 @@ def func(iterator): actual = df.repartition(1).mapInPandas(func, "a long").collect() self.assertEqual(set((r.a for r in actual)), set(range(100))) - def test_other_than_dataframe(self): + def test_other_than_dataframe_iter(self): with QuietTest(self.sc): - self.check_other_than_dataframe() + self.check_other_than_dataframe_iter() - def check_other_than_dataframe(self): - def bad_iter(_): + def check_other_than_dataframe_iter(self): + def no_iter(_): + return 1 + + def bad_iter_elem(_): return iter([1]) with self.assertRaisesRegex( PythonException, - "Return type of the user-defined function should be Pandas.DataFrame, " - "but is ", + "Return type of the user-defined function should be iterator of pandas.DataFrame, " + "but is int.", + ): + (self.spark.range(10, numPartitions=3).mapInPandas(no_iter, "a int").count()) + + with self.assertRaisesRegex( + PythonException, + "Return type of the user-defined function should be iterator of pandas.DataFrame, " + "but is iterator of int.", + ): + (self.spark.range(10, numPartitions=3).mapInPandas(bad_iter_elem, "a int").count()) + + def test_dataframes_with_other_column_names(self): + with QuietTest(self.sc): + self.check_dataframes_with_other_column_names() + + def check_dataframes_with_other_column_names(self): + def dataframes_with_other_column_names(iterator): + for pdf in iterator: + yield pdf.rename(columns={"id": "iid"}) + + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] " + "Column names of the returned 
pandas.DataFrame do not match " + "specified schema. Missing: id. Unexpected: iid.\n", + ): + ( + self.spark.range(10, numPartitions=3) + .withColumn("value", lit(0)) + .mapInPandas(dataframes_with_other_column_names, "id int, value int") + .collect() + ) + + def test_dataframes_with_duplicate_column_names(self): + with QuietTest(self.sc): + self.check_dataframes_with_duplicate_column_names() + + def check_dataframes_with_duplicate_column_names(self): + def dataframes_with_other_column_names(iterator): + for pdf in iterator: + yield pdf.rename(columns={"id2": "id"}) + + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] " + "Column names of the returned pandas.DataFrame do not match " + "specified schema. Missing: id2.\n", ): - self.spark.range(10, numPartitions=3).mapInPandas(bad_iter, "a int, b string").count() + ( + self.spark.range(10, numPartitions=3) + .withColumn("id2", lit(0)) + .withColumn("value", lit(1)) + .mapInPandas(dataframes_with_other_column_names, "id int, id2 long, value int") + .collect() + ) + + def test_dataframes_with_less_columns(self): + with QuietTest(self.sc): + self.check_dataframes_with_less_columns() + + def check_dataframes_with_less_columns(self): + df = self.spark.range(10, numPartitions=3).withColumn("value", lit(0)) + + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] " + "Column names of the returned pandas.DataFrame do not match " + "specified schema. Missing: id2.\n", + ): + f = self.identity_dataframes_iter("id", "value") + (df.mapInPandas(f, "id int, id2 long, value int").collect()) + + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF\\] " + "Number of columns of the returned pandas.DataFrame doesn't match " + "specified schema. Expected: 3 Actual: 2\n", + ): + f = self.identity_dataframes_wo_column_names_iter("id", "value") + (df.mapInPandas(f, "id int, id2 long, value int").collect()) + + def test_dataframes_with_more_columns(self): + df = self.spark.range(10, numPartitions=3).select( + "id", col("id").alias("value"), col("id").alias("extra") + ) + expected = df.select("id", "value").collect() + + f = self.identity_dataframes_iter("id", "value", "extra") + actual = df.repartition(1).mapInPandas(f, "id long, value long").collect() + self.assertEqual(actual, expected) + + f = self.identity_dataframes_wo_column_names_iter("id", "value", "extra") + actual = df.repartition(1).mapInPandas(f, "id long, value long").collect() + self.assertEqual(actual, expected) + + def test_dataframes_with_incompatible_types(self): + with QuietTest(self.sc): + self.check_dataframes_with_incompatible_types() + + def check_dataframes_with_incompatible_types(self): + def func(iterator): + for pdf in iterator: + yield pdf.assign(id=pdf["id"].apply(str)) + + for safely in [True, False]: + with self.subTest(convertToArrowArraySafely=safely), self.sql_conf( + {"spark.sql.execution.pandas.convertToArrowArraySafely": safely} + ): + # sometimes we see ValueErrors + with self.subTest(convert="string to double"): + expected = ( + r"ValueError: Exception thrown when converting pandas.Series " + r"\(object\) with name 'id' to Arrow Array \(double\)." + ) + if safely: + expected = expected + ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. 
Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + ) + with self.assertRaisesRegex(PythonException, expected + "\n"): + ( + self.spark.range(10, numPartitions=3) + .mapInPandas(func, "id double") + .collect() + ) + + # sometimes we see TypeErrors + with self.subTest(convert="double to string"): + with self.assertRaisesRegex( + PythonException, + r"TypeError: Exception thrown when converting pandas.Series " + r"\(float64\) with name 'id' to Arrow Array \(string\).\n", + ): + ( + self.spark.range(10, numPartitions=3) + .select(col("id").cast("double")) + .mapInPandas(self.identity_dataframes_iter("id"), "id string") + .collect() + ) def test_empty_iterator(self): def empty_iter(_): @@ -124,16 +308,8 @@ def empty_dataframes(_): self.assertEqual(mapped.count(), 0) def test_empty_dataframes_without_columns(self): - def empty_dataframes_wo_columns(iterator): - for pdf in iterator: - yield pdf - # after yielding all elements of the iterator, also yield one dataframe without columns - yield pd.DataFrame([]) - - mapped = ( - self.spark.range(10, numPartitions=3) - .toDF("id") - .mapInPandas(empty_dataframes_wo_columns, "id int") + mapped = self.spark.range(10, numPartitions=3).mapInPandas( + self.dataframes_and_empty_dataframe_iter(), "id int" ) self.assertEqual(mapped.count(), 10) @@ -142,16 +318,47 @@ def test_empty_dataframes_with_less_columns(self): self.check_empty_dataframes_with_less_columns() def check_empty_dataframes_with_less_columns(self): - def empty_dataframes_with_less_columns(iterator): - for pdf in iterator: - yield pdf - # after yielding all elements of the iterator, also yield a dataframe with less columns - yield pd.DataFrame([(1,)], columns=["id"]) + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] " + "Column names of the returned pandas.DataFrame do not match " + "specified schema. Missing: value.\n", + ): + f = self.dataframes_and_empty_dataframe_iter("id") + ( + self.spark.range(10, numPartitions=3) + .withColumn("value", lit(0)) + .mapInPandas(f, "id int, value int") + .collect() + ) - with self.assertRaisesRegex(PythonException, "KeyError: 'value'"): - self.spark.range(10, numPartitions=3).withColumn("value", lit(0)).toDF( - "id", "value" - ).mapInPandas(empty_dataframes_with_less_columns, "id int, value int").collect() + def test_empty_dataframes_with_more_columns(self): + mapped = self.spark.range(10, numPartitions=3).mapInPandas( + self.dataframes_and_empty_dataframe_iter("id", "extra"), "id int" + ) + self.assertEqual(mapped.count(), 10) + + def test_empty_dataframes_with_other_columns(self): + with QuietTest(self.sc): + self.check_empty_dataframes_with_other_columns() + + def check_empty_dataframes_with_other_columns(self): + def empty_dataframes_with_other_columns(iterator): + for _ in iterator: + yield pd.DataFrame({"iid": [], "value": []}) + + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] " + "Column names of the returned pandas.DataFrame do not match " + "specified schema. Missing: id. 
Unexpected: iid.\n", + ): + ( + self.spark.range(10, numPartitions=3) + .withColumn("value", lit(0)) + .mapInPandas(empty_dataframes_with_other_columns, "id int, value int") + .collect() + ) def test_chain_map_partitions_in_pandas(self): def func(iterator): diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py index 050f2c3266507..15367743585e3 100644 --- a/python/pyspark/sql/tests/test_arrow_map.py +++ b/python/pyspark/sql/tests/test_arrow_map.py @@ -18,6 +18,7 @@ import time import unittest +from pyspark.sql.utils import PythonException from pyspark.testing.sqlutils import ( ReusedSQLTestCase, have_pandas, @@ -25,6 +26,7 @@ pandas_requirement_message, pyarrow_requirement_message, ) +from pyspark.testing.utils import QuietTest if have_pyarrow: import pyarrow as pa @@ -88,6 +90,31 @@ def func(iterator): actual = df.repartition(1).mapInArrow(func, "a long").collect() self.assertEqual(set((r.a for r in actual)), set(range(100))) + def test_other_than_recordbatch_iter(self): + with QuietTest(self.sc): + self.check_other_than_recordbatch_iter() + + def check_other_than_recordbatch_iter(self): + def not_iter(_): + return 1 + + def bad_iter_elem(_): + return iter([1]) + + with self.assertRaisesRegex( + PythonException, + "Return type of the user-defined function should be iterator " + "of pyarrow.RecordBatch, but is int.", + ): + (self.spark.range(10, numPartitions=3).mapInArrow(not_iter, "a int").count()) + + with self.assertRaisesRegex( + PythonException, + "Return type of the user-defined function should be iterator " + "of pyarrow.RecordBatch, but is iterator of int.", + ): + (self.spark.range(10, numPartitions=3).mapInArrow(bad_iter_elem, "a int").count()) + def test_empty_iterator(self): def empty_iter(_): return iter([]) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8c12312da27bc..8d07772b2148e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -24,6 +24,7 @@ from inspect import currentframe, getframeinfo, getfullargspec import importlib import json +from typing import Iterator # 'resource' is a Unix specific module. 
has_resource_module = True @@ -110,10 +111,13 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_type(result): if not hasattr(result, "__len__"): - pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series" - raise TypeError( - "Return type of the user-defined function should be " - "{}, but is {}".format(pd_type, type(result)) + pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": pd_type, + "actual": type(result).__name__, + }, ) return result @@ -134,67 +138,136 @@ def verify_result_length(result, length): ) -def wrap_batch_iter_udf(f, return_type): +def wrap_pandas_batch_iter_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) + iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - def verify_result_type(result): - if not hasattr(result, "__len__"): - pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series" - raise TypeError( - "Return type of the user-defined function should be " - "{}, but is {}".format(pd_type, type(result)) + def verify_result(result): + if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": "iterator of {}".format(iter_type_label), + "actual": type(result).__name__, + }, ) return result + def verify_element(elem): + import pandas as pd + + if not isinstance(elem, pd.DataFrame if type(return_type) == StructType else pd.Series): + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": "iterator of {}".format(iter_type_label), + "actual": "iterator of {}".format(type(elem).__name__), + }, + ) + + verify_pandas_result( + elem, return_type, assign_cols_by_name=True, truncate_return_schema=True + ) + + return elem + return lambda *iterator: map( - lambda res: (res, arrow_return_type), map(verify_result_type, f(*iterator)) + lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) ) -def verify_pandas_result(result, return_type, assign_cols_by_name): +def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema): import pandas as pd - if not isinstance(result, pd.DataFrame): - raise TypeError( - "Return type of the user-defined function should be " - "pandas.DataFrame, but is {}".format(type(result)) - ) + if type(return_type) == StructType: + if not isinstance(result, pd.DataFrame): + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": "pandas.DataFrame", + "actual": type(result).__name__, + }, + ) + + # check the schema of the result only if it is not empty or has columns + if not result.empty or len(result.columns) != 0: + # if any column name of the result is a string + # the column names of the result have to match the return type + # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer + field_names = set([field.name for field in return_type.fields]) + # only the first len(field_names) result columns are considered + # when truncating the return schema + result_columns = ( + result.columns[: len(field_names)] if truncate_return_schema else result.columns + ) + column_names = set(result_columns) + if ( + assign_cols_by_name + and any(isinstance(name, str) for name in result.columns) + and column_names != field_names + ): + missing = 
sorted(list(field_names.difference(column_names))) + missing = f" Missing: {', '.join(missing)}." if missing else "" - # check the schema of the result only if it is not empty or has columns - if not result.empty or len(result.columns) != 0: - # if any column name of the result is a string - # the column names of the result have to match the return type - # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer - field_names = set([field.name for field in return_type.fields]) - column_names = set(result.columns) - if ( - assign_cols_by_name - and any(isinstance(name, str) for name in result.columns) - and column_names != field_names - ): - missing = sorted(list(field_names.difference(column_names))) - missing = f" Missing: {', '.join(missing)}." if missing else "" - - extra = sorted(list(column_names.difference(field_names))) - extra = f" Unexpected: {', '.join(extra)}." if extra else "" + extra = sorted(list(column_names.difference(field_names))) + extra = f" Unexpected: {', '.join(extra)}." if extra else "" - raise PySparkRuntimeError( - error_class="RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF", + raise PySparkRuntimeError( + error_class="RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF", + message_parameters={ + "missing": missing, + "extra": extra, + }, + ) + # otherwise the number of columns of result have to match the return type + elif len(result_columns) != len(return_type): + raise PySparkRuntimeError( + error_class="RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF", + message_parameters={ + "expected": str(len(return_type)), + "actual": str(len(result.columns)), + }, + ) + else: + if not isinstance(result, pd.Series): + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={"expected": "pandas.Series", "actual": type(result).__name__}, + ) + + +def wrap_arrow_batch_iter_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def verify_result(result): + if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", message_parameters={ - "missing": missing, - "extra": extra, + "expected": "iterator of pyarrow.RecordBatch", + "actual": type(result).__name__, }, ) - # otherwise the number of columns of result have to match the return type - elif len(result.columns) != len(return_type): - raise PySparkRuntimeError( - error_class="RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF", + return result + + def verify_element(elem): + import pyarrow as pa + + if not isinstance(elem, pa.RecordBatch): + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", message_parameters={ - "expected": str(len(return_type)), - "actual": str(len(result.columns)), + "expected": "iterator of pyarrow.RecordBatch", + "actual": "iterator of {}".format(type(elem).__name__), }, ) + return elem + + return lambda *iterator: map( + lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) + ) + def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): _assign_cols_by_name = assign_cols_by_name(runner_conf) @@ -211,7 +284,9 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se key_series = left_key_series if not left_df.empty else right_key_series key = tuple(s[0] for s in key_series) result = f(key, left_df, right_df) - verify_pandas_result(result, return_type, _assign_cols_by_name) + verify_pandas_result( + result, return_type, _assign_cols_by_name, truncate_return_schema=False + ) return result @@ -229,7 +304,9 @@ def wrapped(key_series, 
value_series): elif len(argspec.args) == 2: key = tuple(s[0] for s in key_series) result = f(key, pd.concat(value_series, axis=1)) - verify_pandas_result(result, return_type, _assign_cols_by_name) + verify_pandas_result( + result, return_type, _assign_cols_by_name, truncate_return_schema=False + ) return result @@ -278,9 +355,12 @@ def wrapped(key_series, value_series_gen, state): def verify_element(result): if not isinstance(result, pd.DataFrame): - raise TypeError( - "The type of element in return iterator of the user-defined function " - "should be pandas.DataFrame, but is {}".format(type(result)) + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": "iterator of pandas.DataFrame", + "actual": "iterator of {}".format(type(result).__name__), + }, ) # the number of columns of result have to match the return type # but it is fine for result to have no columns at all if it is empty @@ -299,17 +379,20 @@ def verify_element(result): return result if isinstance(result_iter, pd.DataFrame): - raise TypeError( - "Return type of the user-defined function should be " - "iterable of pandas.DataFrame, but is {}".format(type(result_iter)) + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": "iterable of pandas.DataFrame", + "actual": type(result_iter).__name__, + }, ) try: iter(result_iter) except TypeError: - raise TypeError( - "Return type of the user-defined function should be " - "iterable, but is {}".format(type(result_iter)) + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={"expected": "iterable", "actual": type(result_iter).__name__}, ) result_iter_with_validation = (verify_element(x) for x in result_iter) @@ -423,11 +506,11 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: - return arg_offsets, wrap_batch_iter_udf(func, return_type) + return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: - return arg_offsets, wrap_batch_iter_udf(func, return_type) + return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: - return arg_offsets, wrap_batch_iter_udf(func, return_type) + return arg_offsets, wrap_arrow_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) @@ -544,7 +627,9 @@ def verify_result(result): ) # Verify the type and the schema of the result. - verify_pandas_result(result, return_type, assign_cols_by_name=False) + verify_pandas_result( + result, return_type, assign_cols_by_name=False, truncate_return_schema=False + ) return result return lambda *a: map(lambda res: (res, arrow_return_type), map(verify_result, f(*a)))