diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ca249c75ea5c8..30c2d10245650 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -231,18 +231,25 @@ def create_array(s, t): s = s.astype(s.dtypes.categories.dtype) try: array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) + except TypeError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + raise TypeError(error_msg % (s.dtype, s.name, t)) from e except ValueError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) if self._safecheck: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) to " - + "Arrow Array (%s). It can be caused by overflows or other " - + "unsafe conversions warned by Arrow. Arrow safe type check " - + "can be disabled by using SQL config " - + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + error_msg = error_msg + ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." ) - raise ValueError(error_msg % (s.dtype, t)) from e - else: - raise e + raise ValueError(error_msg % (s.dtype, s.name, t)) from e return array arrs = [] @@ -265,7 +272,9 @@ def create_array(s, t): # Assign result columns by position else: arrs_names = [ - (create_array(s[s.columns[i]], field.type), field.name) + # the selected series has name '1', so we rename it to field.name + # as the name is used by create_array to provide a meaningful error message + (create_array(s[s.columns[i]].rename(field.name), field.type), field.name) for i, field in enumerate(t) ] diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py index 5cbc9e1caa430..47ed12d2f466e 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py @@ -43,7 +43,7 @@ not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) -class CogroupedMapInPandasTests(ReusedSQLTestCase): +class CogroupedApplyInPandasTests(ReusedSQLTestCase): @property def data1(self): return ( @@ -79,7 +79,9 @@ def test_right_group_empty(self): def test_different_schemas(self): right = self.data2.withColumn("v3", lit("a")) - self._test_merge(self.data1, right, "id long, k int, v int, v2 int, v3 string") + self._test_merge( + self.data1, right, output_schema="id long, k int, v int, v2 int, v3 string" + ) def test_different_keys(self): left = self.data1 @@ -128,26 +130,7 @@ def merge_pandas(lft, rgt): assert_frame_equal(expected, result) def test_empty_group_by(self): - left = self.data1 - right = self.data2 - - def merge_pandas(lft, rgt): - return pd.merge(lft, rgt, on=["id", "k"]) - - result = ( - left.groupby() - .cogroup(right.groupby()) - .applyInPandas(merge_pandas, "id long, k int, v int, v2 int") - .sort(["id", "k"]) - .toPandas() - ) - - left = left.toPandas() - right = right.toPandas() - - expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"]) - - assert_frame_equal(expected, result) + self._test_merge(self.data1, self.data2, by=[]) def test_different_group_key_cardinality(self): left = self.data1 @@ -166,29 +149,35 @@ def 
merge_pandas(lft, _): ) def test_apply_in_pandas_not_returning_pandas_dataframe(self): - left = self.data1 - right = self.data2 + self._test_merge_error( + fn=lambda lft, rgt: lft.size + rgt.size, + error_class=PythonException, + error_message_regex="Return type of the user-defined function " + "should be pandas.DataFrame, but is ", + ) + + def test_apply_in_pandas_returning_column_names(self): + self._test_merge(fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"])) + def test_apply_in_pandas_returning_no_column_names(self): def merge_pandas(lft, rgt): - return lft.size + rgt.size + res = pd.merge(lft, rgt, on=["id", "k"]) + res.columns = range(res.columns.size) + return res - with QuietTest(self.sc): - with self.assertRaisesRegex( - PythonException, - "Return type of the user-defined function should be pandas.DataFrame, " - "but is ", - ): - ( - left.groupby("id") - .cogroup(right.groupby("id")) - .applyInPandas(merge_pandas, "id long, k int, v int, v2 int") - .collect() - ) + self._test_merge(fn=merge_pandas) - def test_apply_in_pandas_returning_wrong_number_of_columns(self): - left = self.data1 - right = self.data2 + def test_apply_in_pandas_returning_column_names_sometimes(self): + def merge_pandas(lft, rgt): + res = pd.merge(lft, rgt, on=["id", "k"]) + if 0 in lft["id"] and lft["id"][0] % 2 == 0: + return res + res.columns = range(res.columns.size) + return res + + self._test_merge(fn=merge_pandas) + def test_apply_in_pandas_returning_wrong_column_names(self): def merge_pandas(lft, rgt): if 0 in lft["id"] and lft["id"][0] % 2 == 0: lft["add"] = 0 @@ -196,70 +185,77 @@ def merge_pandas(lft, rgt): rgt["more"] = 1 return pd.merge(lft, rgt, on=["id", "k"]) - with QuietTest(self.sc): - with self.assertRaisesRegex( - PythonException, - "Number of columns of the returned pandas.DataFrame " - "doesn't match specified schema. Expected: 4 Actual: 6", - ): - ( - # merge_pandas returns two columns for even keys while we set schema to four - left.groupby("id") - .cogroup(right.groupby("id")) - .applyInPandas(merge_pandas, "id long, k int, v int, v2 int") - .collect() - ) - - def test_apply_in_pandas_returning_empty_dataframe(self): - left = self.data1 - right = self.data2 + self._test_merge_error( + fn=merge_pandas, + error_class=PythonException, + error_message_regex="Column names of the returned pandas.DataFrame " + "do not match specified schema. Unexpected: add, more.\n", + ) + def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self): def merge_pandas(lft, rgt): if 0 in lft["id"] and lft["id"][0] % 2 == 0: - return pd.DataFrame([]) + lft[3] = 0 if 0 in rgt["id"] and rgt["id"][0] % 3 == 0: - return pd.DataFrame([]) - return pd.merge(lft, rgt, on=["id", "k"]) - - result = ( - left.groupby("id") - .cogroup(right.groupby("id")) - .applyInPandas(merge_pandas, "id long, k int, v int, v2 int") - .sort(["id", "k"]) - .toPandas() + rgt[3] = 1 + res = pd.merge(lft, rgt, on=["id", "k"]) + res.columns = range(res.columns.size) + return res + + self._test_merge_error( + fn=merge_pandas, + error_class=PythonException, + error_message_regex="Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. 
Expected: 4 Actual: 6\n", ) - left = left.toPandas() - right = right.toPandas() - - expected = pd.merge( - left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", "k"] - ).sort_values(by=["id", "k"]) - - assert_frame_equal(expected, result) - - def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self): - left = self.data1 - right = self.data2 - + def test_apply_in_pandas_returning_empty_dataframe(self): def merge_pandas(lft, rgt): if 0 in lft["id"] and lft["id"][0] % 2 == 0: - return pd.DataFrame([], columns=["id", "k"]) + return pd.DataFrame() + if 0 in rgt["id"] and rgt["id"][0] % 3 == 0: + return pd.DataFrame() return pd.merge(lft, rgt, on=["id", "k"]) - with QuietTest(self.sc): - with self.assertRaisesRegex( - PythonException, - "Number of columns of the returned pandas.DataFrame doesn't " - "match specified schema. Expected: 4 Actual: 2", - ): - ( - # merge_pandas returns two columns for even keys while we set schema to four - left.groupby("id") - .cogroup(right.groupby("id")) - .applyInPandas(merge_pandas, "id long, k int, v int, v2 int") - .collect() - ) + self._test_merge_empty(fn=merge_pandas) + + def test_apply_in_pandas_returning_incompatible_type(self): + for safely in [True, False]: + with self.subTest(convertToArrowArraySafely=safely), self.sql_conf( + {"spark.sql.execution.pandas.convertToArrowArraySafely": safely} + ), QuietTest(self.sc): + # sometimes we see ValueErrors + with self.subTest(convert="string to double"): + expected = ( + r"ValueError: Exception thrown when converting pandas.Series \(object\) " + r"with name 'k' to Arrow Array \(double\)." + ) + if safely: + expected = expected + ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." 
+ ) + self._test_merge_error( + fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}), + output_schema="id long, k double", + error_class=PythonException, + error_message_regex=expected, + ) + + # sometimes we see TypeErrors + with self.subTest(convert="double to string"): + expected = ( + r"TypeError: Exception thrown when converting pandas.Series \(float64\) " + r"with name 'k' to Arrow Array \(string\).\n" + ) + self._test_merge_error( + fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": [2.0]}), + output_schema="id long, k string", + error_class=PythonException, + error_message_regex=expected, + ) def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self): df = self.spark.range(0, 10).toDF("v1") @@ -312,23 +308,20 @@ def left_assign_key(key, lft, _): def test_wrong_return_type(self): # Test that we get a sensible exception invalid values passed to apply - left = self.data1 - right = self.data2 - with QuietTest(self.sc): - with self.assertRaisesRegex( - NotImplementedError, "Invalid return type.*ArrayType.*TimestampType" - ): - left.groupby("id").cogroup(right.groupby("id")).applyInPandas( - lambda l, r: l, "id long, v array" - ) + self._test_merge_error( + fn=lambda l, r: l, + output_schema="id long, v array", + error_class=NotImplementedError, + error_message_regex="Invalid return type.*ArrayType.*TimestampType", + ) def test_wrong_args(self): - left = self.data1 - right = self.data2 - with self.assertRaisesRegex(ValueError, "Invalid function"): - left.groupby("id").cogroup(right.groupby("id")).applyInPandas( - lambda: 1, StructType([StructField("d", DoubleType())]) - ) + self.__test_merge_error( + fn=lambda: 1, + output_schema=StructType([StructField("d", DoubleType())]), + error_class=ValueError, + error_message_regex="Invalid function", + ) def test_case_insensitive_grouping_column(self): # SPARK-31915: case-insensitive grouping column should work. @@ -434,15 +427,51 @@ def right_assign_key(key, lft, rgt): assert_frame_equal(expected, result) - @staticmethod - def _test_merge(left, right, output_schema="id long, k int, v int, v2 int"): - def merge_pandas(lft, rgt): - return pd.merge(lft, rgt, on=["id", "k"]) + def _test_merge_empty(self, fn): + left = self.data1.toPandas() + right = self.data2.toPandas() + + expected = pd.merge( + left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", "k"] + ).sort_values(by=["id", "k"]) + + self._test_merge(self.data1, self.data2, fn=fn, expected=expected) + + def _test_merge( + self, + left=None, + right=None, + by=["id"], + fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]), + output_schema="id long, k int, v int, v2 int", + expected=None, + ): + def fn_with_key(_, lft, rgt): + return fn(lft, rgt) + + # Test fn with and without key argument + with self.subTest("without key"): + self.__test_merge(left, right, by, fn, output_schema, expected) + with self.subTest("with key"): + self.__test_merge(left, right, by, fn_with_key, output_schema, expected) + + def __test_merge( + self, + left=None, + right=None, + by=["id"], + fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]), + output_schema="id long, k int, v int, v2 int", + expected=None, + ): + # Test fn as is, cf. 
_test_merge + left = self.data1 if left is None else left + right = self.data2 if right is None else right result = ( - left.groupby("id") - .cogroup(right.groupby("id")) - .applyInPandas(merge_pandas, output_schema) + left.groupby(*by) + .cogroup(right.groupby(*by)) + .applyInPandas(fn, output_schema) .sort(["id", "k"]) .toPandas() ) @@ -450,10 +479,64 @@ def merge_pandas(lft, rgt): left = left.toPandas() right = right.toPandas() - expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"]) + expected = ( + pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"]) + if expected is None + else expected + ) assert_frame_equal(expected, result) + def _test_merge_error( + self, + error_class, + error_message_regex, + left=None, + right=None, + by=["id"], + fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]), + output_schema="id long, k int, v int, v2 int", + ): + def fn_with_key(_, lft, rgt): + return fn(lft, rgt) + + # Test fn with and without key argument + with self.subTest("without key"): + self.__test_merge_error( + left=left, + right=right, + by=by, + fn=fn, + output_schema=output_schema, + error_class=error_class, + error_message_regex=error_message_regex, + ) + with self.subTest("with key"): + self.__test_merge_error( + left=left, + right=right, + by=by, + fn=fn_with_key, + output_schema=output_schema, + error_class=error_class, + error_message_regex=error_message_regex, + ) + + def __test_merge_error( + self, + error_class, + error_message_regex, + left=None, + right=None, + by=["id"], + fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]), + output_schema="id long, k int, v int, v2 int", + ): + # Test fn as is, cf. _test_merge_error + with QuietTest(self.sc): + with self.assertRaisesRegex(error_class, error_message_regex): + self.__test_merge(left, right, by, fn, output_schema) + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import * # noqa: F401 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 5f103c97926a3..88e68b043035e 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -73,7 +73,7 @@ not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) -class GroupedMapInPandasTests(ReusedSQLTestCase): +class GroupedApplyInPandasTests(ReusedSQLTestCase): @property def data(self): return ( @@ -270,79 +270,101 @@ def normalize(pdf): assert_frame_equal(expected, result) def test_apply_in_pandas_not_returning_pandas_dataframe(self): - df = self.data - - def stats(key, _): - return key - with QuietTest(self.sc): with self.assertRaisesRegex( PythonException, "Return type of the user-defined function should be pandas.DataFrame, " "but is ", ): - df.groupby("id").applyInPandas(stats, schema="id integer, m double").collect() + self._test_apply_in_pandas(lambda key, pdf: key) - def test_apply_in_pandas_returning_wrong_number_of_columns(self): - df = self.data + @staticmethod + def stats_with_column_names(key, pdf): + # order of column can be different to applyInPandas schema when column names are given + return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"]) - def stats(key, pdf): - v = pdf.v - # returning three columns - res = pd.DataFrame([key + (v.mean(), v.std())]) - return res + @staticmethod + def stats_with_no_column_names(key, pdf): + # columns must be in order of applyInPandas schema 
when no columns given + return pd.DataFrame([key + (pdf.v.mean(),)]) - with QuietTest(self.sc): - with self.assertRaisesRegex( - PythonException, - "Number of columns of the returned pandas.DataFrame doesn't match " - "specified schema. Expected: 2 Actual: 3", - ): - # stats returns three columns while here we set schema with two columns - df.groupby("id").applyInPandas(stats, schema="id integer, m double").collect() + def test_apply_in_pandas_returning_column_names(self): + self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_column_names) - def test_apply_in_pandas_returning_empty_dataframe(self): - df = self.data + def test_apply_in_pandas_returning_no_column_names(self): + self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_no_column_names) - def odd_means(key, pdf): - if key[0] % 2 == 0: - return pd.DataFrame([]) + def test_apply_in_pandas_returning_column_names_sometimes(self): + def stats(key, pdf): + if key[0] % 2: + return GroupedApplyInPandasTests.stats_with_column_names(key, pdf) else: - return pd.DataFrame([key + (pdf.v.mean(),)]) + return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf) - expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 != 0} + self._test_apply_in_pandas(stats) - result = ( - df.groupby("id") - .applyInPandas(odd_means, schema="id integer, m double") - .sort("id", "m") - .collect() - ) - - actual_ids = {row[0] for row in result} - self.assertSetEqual(expected_ids, actual_ids) - - self.assertEqual(len(expected_ids), len(result)) - for row in result: - self.assertEqual(24.5, row[1]) - - def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self): - df = self.data - - def odd_means(key, pdf): - if key[0] % 2 == 0: - return pd.DataFrame([], columns=["id"]) - else: - return pd.DataFrame([key + (pdf.v.mean(),)]) + def test_apply_in_pandas_returning_wrong_column_names(self): + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + "Column names of the returned pandas.DataFrame do not match specified schema. " + "Missing: mean. Unexpected: median, std.\n", + ): + self._test_apply_in_pandas( + lambda key, pdf: pd.DataFrame( + [key + (pdf.v.median(), pdf.v.std())], columns=["id", "median", "std"] + ) + ) + def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self): with QuietTest(self.sc): with self.assertRaisesRegex( PythonException, "Number of columns of the returned pandas.DataFrame doesn't match " - "specified schema. Expected: 2 Actual: 1", + "specified schema. Expected: 2 Actual: 3\n", ): - # stats returns one column for even keys while here we set schema with two columns - df.groupby("id").applyInPandas(odd_means, schema="id integer, m double").collect() + self._test_apply_in_pandas( + lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())]) + ) + + def test_apply_in_pandas_returning_empty_dataframe(self): + self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame()) + + def test_apply_in_pandas_returning_incompatible_type(self): + for safely in [True, False]: + with self.subTest(convertToArrowArraySafely=safely), self.sql_conf( + {"spark.sql.execution.pandas.convertToArrowArraySafely": safely} + ), QuietTest(self.sc): + # sometimes we see ValueErrors + with self.subTest(convert="string to double"): + expected = ( + r"ValueError: Exception thrown when converting pandas.Series \(object\) " + r"with name 'mean' to Arrow Array \(double\)." 
+ ) + if safely: + expected = expected + ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + ) + with self.assertRaisesRegex(PythonException, expected + "\n"): + self._test_apply_in_pandas( + lambda key, pdf: pd.DataFrame([key + (str(pdf.v.mean()),)]), + output_schema="id long, mean double", + ) + + # sometimes we see TypeErrors + with self.subTest(convert="double to string"): + with self.assertRaisesRegex( + PythonException, + r"TypeError: Exception thrown when converting pandas.Series \(float64\) " + r"with name 'mean' to Arrow Array \(string\).\n", + ): + self._test_apply_in_pandas( + lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]), + output_schema="id long, mean string", + ) def test_datatype_string(self): df = self.data @@ -566,7 +588,11 @@ def invalid_positional_types(pdf): with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}): with QuietTest(self.sc): - with self.assertRaisesRegex(Exception, "KeyError: 'id'"): + with self.assertRaisesRegex( + PythonException, + "RuntimeError: Column names of the returned pandas.DataFrame do not match " + "specified schema. Missing: id. Unexpected: iid.\n", + ): grouped_df.apply(column_name_typo).collect() with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"): grouped_df.apply(invalid_positional_types).collect() @@ -655,10 +681,11 @@ def f(pdf): df.groupby("group", window("ts", "5 days")) .applyInPandas(f, df.schema) .select("id", "result") + .orderBy("id") .collect() ) - for r in result: - self.assertListEqual(expected[r[0]], r[1]) + + self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result) def test_grouped_over_window_with_key(self): @@ -720,11 +747,11 @@ def f(key, pdf): df.groupby("group", window("ts", "5 days")) .applyInPandas(f, df.schema) .select("id", "result") + .orderBy("id") .collect() ) - for r in result: - self.assertListEqual(expected[r[0]], r[1]) + self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result) def test_case_insensitive_grouping_column(self): # SPARK-31915: case-insensitive grouping column should work. 
@@ -739,6 +766,44 @@ def my_pandas_udf(pdf): ) self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict()) + def _test_apply_in_pandas(self, f, output_schema="id long, mean double"): + df = self.data + + result = ( + df.groupby("id").applyInPandas(f, schema=output_schema).sort("id", "mean").toPandas() + ) + expected = df.select("id").distinct().withColumn("mean", lit(24.5)).toPandas() + + assert_frame_equal(expected, result) + + def _test_apply_in_pandas_returning_empty_dataframe(self, empty_df): + """Tests some returned DataFrames are empty.""" + df = self.data + + def stats(key, pdf): + if key[0] % 2 == 0: + return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf) + return empty_df + + result = ( + df.groupby("id") + .applyInPandas(stats, schema="id long, mean double") + .sort("id", "mean") + .collect() + ) + + actual_ids = {row[0] for row in result} + expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 == 0} + self.assertSetEqual(expected_ids, actual_ids) + self.assertEqual(len(expected_ids), len(result)) + for row in result: + self.assertEqual(24.5, row[1]) + + def _test_apply_in_pandas_returning_empty_dataframe_error(self, empty_df, error): + with QuietTest(self.sc): + with self.assertRaisesRegex(PythonException, error): + self._test_apply_in_pandas_returning_empty_dataframe(empty_df) + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_grouped_map import * # noqa: F401 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py index 655f0bf151d53..9600d1e344575 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py @@ -53,7 +53,7 @@ not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) -class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): +class GroupedApplyInPandasWithStateTests(ReusedSQLTestCase): @classmethod def conf(cls): cfg = SparkConf() diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 6083f31ac81b9..c61994380e628 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -465,9 +465,24 @@ def test_createDataFrame_with_incorrect_schema(self): wrong_schema = StructType(fields) with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}): with QuietTest(self.sc): - with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"): + with self.assertRaises(Exception) as context: self.spark.createDataFrame(pdf, schema=wrong_schema) + # the exception provides us with the column that is incorrect + exception = context.exception + self.assertTrue(hasattr(exception, "args")) + self.assertEqual(len(exception.args), 1) + self.assertRegex( + exception.args[0], + "with name '7_date_t' " "to Arrow Array \\(decimal128\\(38, 18\\)\\)", + ) + + # the inner exception provides us with the incorrect types + exception = exception.__context__ + self.assertTrue(hasattr(exception, "args")) + self.assertEqual(len(exception.args), 1) + self.assertRegex(exception.args[0], "[D|d]ecimal.*got.*date") + def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() new_names = list(map(str, range(len(self.schema.fieldNames())))) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c1c3669701f7e..f7d98a9a18c03 100644 --- a/python/pyspark/worker.py +++ 
b/python/pyspark/worker.py @@ -146,7 +146,49 @@ def verify_result_type(result): ) -def wrap_cogrouped_map_pandas_udf(f, return_type, argspec): +def verify_pandas_result(result, return_type, assign_cols_by_name): + import pandas as pd + + if not isinstance(result, pd.DataFrame): + raise TypeError( + "Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result)) + ) + + # check the schema of the result only if it is not empty or has columns + if not result.empty or len(result.columns) != 0: + # if any column name of the result is a string + # the column names of the result have to match the return type + # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer + field_names = set([field.name for field in return_type.fields]) + column_names = set(result.columns) + if ( + assign_cols_by_name + and any(isinstance(name, str) for name in result.columns) + and column_names != field_names + ): + missing = sorted(list(field_names.difference(column_names))) + missing = f" Missing: {', '.join(missing)}." if missing else "" + + extra = sorted(list(column_names.difference(field_names))) + extra = f" Unexpected: {', '.join(extra)}." if extra else "" + + raise RuntimeError( + "Column names of the returned pandas.DataFrame do not match specified schema." + "{}{}".format(missing, extra) + ) + # otherwise the number of columns of result have to match the return type + elif len(result.columns) != len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + + +def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): + _assign_cols_by_name = assign_cols_by_name(runner_conf) + def wrapped(left_key_series, left_value_series, right_key_series, right_value_series): import pandas as pd @@ -159,27 +201,16 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se key_series = left_key_series if not left_df.empty else right_key_series key = tuple(s[0] for s in key_series) result = f(key, left_df, right_df) - if not isinstance(result, pd.DataFrame): - raise TypeError( - "Return type of the user-defined function should be " - "pandas.DataFrame, but is {}".format(type(result)) - ) - # the number of columns of result have to match the return type - # but it is fine for result to have no columns at all if it is empty - if not ( - len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty - ): - raise RuntimeError( - "Number of columns of the returned pandas.DataFrame " - "doesn't match specified schema. 
" - "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) - ) + verify_pandas_result(result, return_type, _assign_cols_by_name) + return result return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))] -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + _assign_cols_by_name = assign_cols_by_name(runner_conf) + def wrapped(key_series, value_series): import pandas as pd @@ -188,22 +219,8 @@ def wrapped(key_series, value_series): elif len(argspec.args) == 2: key = tuple(s[0] for s in key_series) result = f(key, pd.concat(value_series, axis=1)) + verify_pandas_result(result, return_type, _assign_cols_by_name) - if not isinstance(result, pd.DataFrame): - raise TypeError( - "Return type of the user-defined function should be " - "pandas.DataFrame, but is {}".format(type(result)) - ) - # the number of columns of result have to match the return type - # but it is fine for result to have no columns at all if it is empty - if not ( - len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty - ): - raise RuntimeError( - "Number of columns of the returned pandas.DataFrame " - "doesn't match specified schema. " - "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) - ) return result return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] @@ -396,12 +413,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): return arg_offsets, wrap_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it - return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) + return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -412,6 +429,16 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): raise ValueError("Unknown eval type: {}".format(eval_type)) +# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType +def assign_cols_by_name(runner_conf): + return ( + runner_conf.get( + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true" + ).lower() + == "true" + ) + + def read_udfs(pickleSer, infile, eval_type): runner_conf = {} @@ -444,16 +471,9 @@ def read_udfs(pickleSer, infile, eval_type): runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() == "true" ) - # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType - assign_cols_by_name = ( - runner_conf.get( - "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true" - ).lower() - == "true" - ) if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - ser = CogroupUDFSerializer(timezone, safecheck, 
assign_cols_by_name) + ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf)) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: arrow_max_records_per_batch = runner_conf.get( "spark.sql.execution.arrow.maxRecordsPerBatch", 10000 @@ -463,7 +483,7 @@ def read_udfs(pickleSer, infile, eval_type): ser = ApplyInPandasWithStateSerializer( timezone, safecheck, - assign_cols_by_name, + assign_cols_by_name(runner_conf), state_object_schema, arrow_max_records_per_batch, ) @@ -478,7 +498,7 @@ def read_udfs(pickleSer, infile, eval_type): or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF ) ser = ArrowStreamPandasUDFSerializer( - timezone, safecheck, assign_cols_by_name, df_for_struct + timezone, safecheck, assign_cols_by_name(runner_conf), df_for_struct ) else: ser = BatchedSerializer(CPickleSerializer(), 100)