From 4bdd86859ff89d460dc8072e158c2eff47022975 Mon Sep 17 00:00:00 2001 From: Daren Liang Date: Wed, 6 Nov 2024 16:29:51 -0500 Subject: [PATCH] fix: remove Category inheritance from ArrowDictionary Signed-off-by: Daren Liang --- pandera/engines/pandas_engine.py | 2 +- pandera/engines/pyarrow_engine.py | 2 +- tests/core/test_pandas_engine.py | 79 ++++++++++++++++++++++--------- 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/pandera/engines/pandas_engine.py b/pandera/engines/pandas_engine.py index 44228f62d..d0619eb4c 100644 --- a/pandera/engines/pandas_engine.py +++ b/pandera/engines/pandas_engine.py @@ -1647,7 +1647,7 @@ def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType): equivalents=[pyarrow.dictionary, pyarrow.DictionaryType] ) @immutable(init=True) - class ArrowDictionary(DataType, dtypes.Category): + class ArrowDictionary(DataType): """Semantic representation of a :class:`pyarrow.dictionary`.""" type: Optional[pd.ArrowDtype] = dataclasses.field( diff --git a/pandera/engines/pyarrow_engine.py b/pandera/engines/pyarrow_engine.py index 4f42bf688..a0ad502a8 100644 --- a/pandera/engines/pyarrow_engine.py +++ b/pandera/engines/pyarrow_engine.py @@ -253,7 +253,7 @@ def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType): equivalents=[pyarrow.dictionary, pyarrow.DictionaryType] ) @immutable(init=True) -class ArrowDictionary(DataType, dtypes.Category): +class ArrowDictionary(DataType): """Semantic representation of a :class:`pyarrow.dictionary`.""" type: Optional[pd.ArrowDtype] = dataclasses.field(default=None, init=False) diff --git a/tests/core/test_pandas_engine.py b/tests/core/test_pandas_engine.py index bd4187c2a..e22e28a7e 100644 --- a/tests/core/test_pandas_engine.py +++ b/tests/core/test_pandas_engine.py @@ -277,6 +277,10 @@ def test_pandas_date_coerce_dtype(to_df, data): (pd.Series(["foo", "bar", "baz", None]), pyarrow.binary(3)), (pd.Series(["foo", "barbaz", None]), pyarrow.large_binary()), (pd.Series(["1", "1.0", "foo", "bar", None]), pyarrow.large_string()), + ( + pd.Series(["a", "b", "c"]), + pyarrow.dictionary(pyarrow.int64(), pyarrow.string()), + ), ) @@ -289,17 +293,20 @@ def test_pandas_arrow_dtype(data, dtype): pytest.skip("Support of pandas 2.0.0+ with pyarrow only") dtype = pandas_engine.Engine.dtype(dtype) - dtype.coerce(data) + coerced_data = dtype.coerce(data) + assert coerced_data.dtype == dtype.type pandas_arrow_dtype_error_cases = ( ( pd.Series([["a", "b", "c"]]), pyarrow.list_(pyarrow.int64()), + pyarrow.ArrowInvalid, ), ( pd.Series([["a", "b"]]), pyarrow.list_(pyarrow.string(), 3), + pyarrow.ArrowInvalid, ), ( pd.Series([{"foo": 1, "bar": "a"}]), @@ -309,13 +316,22 @@ def test_pandas_arrow_dtype(data, dtype): ("bar", pyarrow.int64()), ] ), + pyarrow.ArrowTypeError, + ), + (pd.Series(["a", "1"]), pyarrow.null, NotImplementedError), + ( + pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), + pyarrow.date32, + pyarrow.ArrowTypeError, + ), + ( + pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), + pyarrow.date64, + pyarrow.ArrowTypeError, ), - (pd.Series(["a", "1"]), pyarrow.null), - (pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date32), - (pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date64), - (pd.Series(["a"]), pyarrow.duration("ns")), - (pd.Series(["a", "b"]), pyarrow.time32("ms")), - (pd.Series(["a", "b"]), pyarrow.time64("ns")), + (pd.Series(["a"]), pyarrow.duration("ns"), ValueError), + (pd.Series(["a", "b"]), pyarrow.time32("ms"), ValueError), + (pd.Series(["a", "b"]), pyarrow.time64("ns"), ValueError), ( pd.Series( [ @@ -324,16 +340,41 @@ def test_pandas_arrow_dtype(data, dtype): ] ), pyarrow.map_(pyarrow.int32(), pyarrow.string()), + NotImplementedError, + ), + (pd.Series([1, "foo", None]), pyarrow.binary(), pyarrow.ArrowInvalid), + ( + pd.Series(["foo", "bar", "baz", None]), + pyarrow.binary(2), + NotImplementedError, + ), + ( + pd.Series([1, "foo", "barbaz", None]), + pyarrow.large_binary(), + pyarrow.ArrowInvalid, + ), + ( + pd.Series([1, 1.0, "foo", "bar", None]), + pyarrow.large_string(), + pyarrow.ArrowInvalid, + ), + ( + pd.Series([1.0, 2.0, 3.0]), + pyarrow.dictionary(pyarrow.int64(), pyarrow.float64()), + NotImplementedError, + ), + ( + pd.Series(["a", "b", "c"]), + pyarrow.dictionary(pyarrow.int64(), pyarrow.int64()), + AssertionError, ), - (pd.Series([1, "foo", None]), pyarrow.binary()), - (pd.Series(["foo", "bar", "baz", None]), pyarrow.binary(2)), - (pd.Series([1, "foo", "barbaz", None]), pyarrow.large_binary()), - (pd.Series([1, 1.0, "foo", "bar", None]), pyarrow.large_string()), ) -@pytest.mark.parametrize(("data", "dtype"), pandas_arrow_dtype_error_cases) -def test_pandas_arrow_dtype_error(data, dtype): +@pytest.mark.parametrize( + ("data", "dtype", "exc"), pandas_arrow_dtype_error_cases +) +def test_pandas_arrow_dtype_error(data, dtype, exc): """Test pyarrow dtype raises Error on bad data.""" if not ( pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS @@ -341,12 +382,6 @@ def test_pandas_arrow_dtype_error(data, dtype): pytest.skip("Support of pandas 2.0.0+ with pyarrow only") dtype = pandas_engine.Engine.dtype(dtype) - with pytest.raises( - ( - pyarrow.ArrowInvalid, - pyarrow.ArrowTypeError, - NotImplementedError, - ValueError, - ) - ): - dtype.coerce(data) + with pytest.raises(exc): + coerced_data = dtype.coerce(data) + assert coerced_data.dtype == dtype.type