diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py
index 958314c5741ff..e3e9272a33af6 100644
--- a/python/pyspark/pandas/tests/indexes/test_base.py
+++ b/python/pyspark/pandas/tests/indexes/test_base.py
@@ -25,6 +25,7 @@
 
 import pyspark.pandas as ps
 from pyspark.pandas.exceptions import PandasNotImplementedError
+from pyspark.pandas.exceptions import SparkPandasNotImplementedError
 from pyspark.pandas.missing.indexes import (
     MissingPandasLikeDatetimeIndex,
     MissingPandasLikeIndex,
@@ -65,6 +66,14 @@ def test_index_basic(self):
         with self.assertRaisesRegexp(ValueError, "The truth value of a Int64Index is ambiguous."):
             bool(ps.Index([1]))
 
+        # Negative
+        with self.assertRaises(SparkPandasNotImplementedError):
+            ps.Index([1, '2'])
+        with self.assertRaises(SparkPandasNotImplementedError):
+            ps.Index([[1, '2'], ['A', 'B']])
+        with self.assertRaises(SparkPandasNotImplementedError):
+            ps.Index([[1, 'A'], [2, 'B']])
+
     def test_index_from_series(self):
         pser = pd.Series([1, 2, 3], name="a", index=[10, 20, 30])
         psser = ps.from_pandas(pser)
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 1361c44404a3d..35af21fdc7d99 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -33,6 +33,7 @@
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
 from pyspark.pandas.exceptions import PandasNotImplementedError
+from pyspark.pandas.exceptions import SparkPandasNotImplementedError
 from pyspark.pandas.frame import CachedDataFrame
 from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame
 from pyspark.pandas.typedef.typehints import (
@@ -116,6 +117,14 @@ def test_creation_index(self):
         with self.assertRaisesRegex(TypeError, err_msg):
             ps.DataFrame([1, 2], index=ps.MultiIndex.from_tuples([(1, 3), (2, 4)]))
 
+        # Negative
+        with self.assertRaises(SparkPandasNotImplementedError):
+            ps.DataFrame([1, 2], index=[1, '2'])
+        with self.assertRaises(SparkPandasNotImplementedError):
+            ps.DataFrame([1, 2], index=[[1, '2'], ['A', 'B']])
+        with self.assertRaises(SparkPandasNotImplementedError):
+            ps.DataFrame([1, 2], index=[[1, 'A'], [2, 'B']])
+
     def _check_extension(self, psdf, pdf):
         if LooseVersion("1.1") <= LooseVersion(pd.__version__) < LooseVersion("1.2.2"):
             self.assert_eq(psdf, pdf, check_exact=False)
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 144df0f986a70..229659d7ff077 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -36,6 +36,7 @@
 )
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.pandas.exceptions import PandasNotImplementedError
+from pyspark.pandas.exceptions import SparkPandasNotImplementedError
 from pyspark.pandas.missing.series import MissingPandasLikeSeries
 from pyspark.pandas.typedef.typehints import (
     extension_dtypes,
@@ -3325,6 +3326,11 @@ def test_transform(self):
         ):
             psser.transform(lambda x: x + 1, axis=1)
 
+    def test_series_creation(self):
+        # Negative
+        with self.assertRaises(SparkPandasNotImplementedError):
+            ps.Series([1, 2, '3'])
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_series import *  # noqa: F401
diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py
index 8a32a14b64e72..664fda93285f7 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -32,6 +32,8 @@
 from pandas.api.types import CategoricalDtype, pandas_dtype  # type: ignore[attr-defined]
 from pandas.api.extensions import ExtensionDtype
 
+from pyspark.pandas.exceptions import SparkPandasNotImplementedError
+
 extension_dtypes: Tuple[type, ...]
 try:
     from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
@@ -357,7 +359,18 @@ def infer_pd_series_spark_type(
         elif hasattr(pser.iloc[0], "__UDT__"):
             return pser.iloc[0].__UDT__
         else:
-            return from_arrow_type(pa.Array.from_pandas(pser).type, prefer_timestamp_ntz)
+            try:
+                internal_frame = pa.Array.from_pandas(pser)
+            except (pa.lib.ArrowInvalid, pa.lib.ArrowTypeError):
+                raise SparkPandasNotImplementedError(
+                    description="PySpark requires elements of homogeneous type for "
+                    "DataFrame, Series and Index; mixed types such as "
+                    "Object([typeA, typeB]) are supported by pandas but not by "
+                    "PySpark, so keep all values of the same dtype. "
+                    "Got {} and dtype ({}).".format(
+                        str(pser.values), str(pser.dtypes))
+                )
+            return from_arrow_type(internal_frame.type, prefer_timestamp_ntz)
     elif isinstance(dtype, CategoricalDtype):
         if isinstance(pser.dtype, CategoricalDtype):
             return as_spark_type(pser.cat.codes.dtype, prefer_timestamp_ntz=prefer_timestamp_ntz)
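
A minimal usage sketch, not part of the patch, of the behavior the typehints.py change is
expected to produce, assuming a working PySpark installation where pyspark.pandas can start
a Spark session: mixed-type input that PyArrow cannot coerce to a single Arrow type should
surface as SparkPandasNotImplementedError rather than a raw ArrowInvalid/ArrowTypeError.

    import pyspark.pandas as ps
    from pyspark.pandas.exceptions import SparkPandasNotImplementedError

    try:
        # Mixed int/str values in one object column, as exercised by the new tests.
        ps.Series([1, 2, '3'])
    except SparkPandasNotImplementedError as e:
        print(e)  # message explains that homogeneous element types are required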