From 19650a3ee8d2d355a51238a9b50efcd96608d223 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 13 Aug 2024 15:46:03 -0700 Subject: [PATCH] BUG (string): ArrowEA comparisons with mismatched types (#59505) * BUG: ArrowEA comparisons with mismatched types * move whatsnew * GH ref --- pandas/core/arrays/arrow/array.py | 8 ++++++- pandas/core/arrays/string_arrow.py | 6 +---- pandas/tests/series/test_logical_ops.py | 31 +++++++++++++++++++++---- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 6c44b7759f0e29..46f2cbb2ebeefc 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -704,7 +704,13 @@ def _cmp_method(self, other, op): if isinstance( other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray) ) or isinstance(getattr(other, "dtype", None), CategoricalDtype): - result = pc_func(self._pa_array, self._box_pa(other)) + try: + result = pc_func(self._pa_array, self._box_pa(other)) + except pa.ArrowNotImplementedError: + # TODO: could this be wrong if other is object dtype? + # in which case we need to operate pointwise? + result = ops.invalid_comparison(self, other, op) + result = pa.array(result, type=pa.bool_()) elif is_scalar(other): try: result = pc_func(self._pa_array, self._box_pa(other)) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 7119f6321a7ffb..00d92958cc8dc4 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -37,7 +37,6 @@ BaseStringArray, StringDtype, ) -from pandas.core.ops import invalid_comparison from pandas.core.strings.object_array import ObjectStringArrayMixin if not pa_version_under10p1: @@ -565,10 +564,7 @@ def _convert_int_dtype(self, result): return result def _cmp_method(self, other, op): - try: - result = super()._cmp_method(other, op) - except pa.ArrowNotImplementedError: - return invalid_comparison(self, other, op) + result = super()._cmp_method(other, op) if op == operator.ne: return result.to_numpy(np.bool_, na_value=True) else: diff --git a/pandas/tests/series/test_logical_ops.py b/pandas/tests/series/test_logical_ops.py index d3a718be47c18d..ff21427b71cf9e 100644 --- a/pandas/tests/series/test_logical_ops.py +++ b/pandas/tests/series/test_logical_ops.py @@ -9,6 +9,7 @@ from pandas.compat import HAS_PYARROW from pandas import ( + ArrowDtype, DataFrame, Index, Series, @@ -539,18 +540,38 @@ def test_int_dtype_different_index_not_bool(self): result = ser1 ^ ser2 tm.assert_series_equal(result, expected) + # TODO: this belongs in comparison tests def test_pyarrow_numpy_string_invalid(self): # GH#56008 - pytest.importorskip("pyarrow") + pa = pytest.importorskip("pyarrow") ser = Series([False, True]) ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]") result = ser == ser2 - expected = Series(False, index=ser.index) - tm.assert_series_equal(result, expected) + expected_eq = Series(False, index=ser.index) + tm.assert_series_equal(result, expected_eq) result = ser != ser2 - expected = Series(True, index=ser.index) - tm.assert_series_equal(result, expected) + expected_ne = Series(True, index=ser.index) + tm.assert_series_equal(result, expected_ne) with pytest.raises(TypeError, match="Invalid comparison"): ser > ser2 + + # GH#59505 + ser3 = ser2.astype("string[pyarrow]") + result3_eq = ser3 == ser + tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]")) + result3_ne = ser3 != ser + tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]")) + + with pytest.raises(TypeError, match="Invalid comparison"): + ser > ser3 + + ser4 = ser2.astype(ArrowDtype(pa.string())) + result4_eq = ser4 == ser + tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]")) + result4_ne = ser4 != ser + tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]")) + + with pytest.raises(TypeError, match="Invalid comparison"): + ser > ser4