Skip to content

Commit

Permalink
BUG (string): ArrowEA comparisons with mismatched types (pandas-dev#5…
Browse files Browse the repository at this point in the history
…9505)

* BUG: ArrowEA comparisons with mismatched types

* move whatsnew

* GH ref
  • Loading branch information
jbrockmendel authored and jorisvandenbossche committed Oct 2, 2024
1 parent 4dcc226 commit d755421
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
8 changes: 7 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,13 @@ def _cmp_method(self, other, op):
if isinstance(
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
result = pc_func(self._pa_array, self._box_pa(other))
try:
result = pc_func(self._pa_array, self._box_pa(other))
except pa.ArrowNotImplementedError:
# TODO: could this be wrong if other is object dtype?
# in which case we need to operate pointwise?
result = ops.invalid_comparison(self, other, op)
result = pa.array(result, type=pa.bool_())
elif is_scalar(other):
try:
result = pc_func(self._pa_array, self._box_pa(other))
Expand Down
6 changes: 1 addition & 5 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
BaseStringArray,
StringDtype,
)
from pandas.core.ops import invalid_comparison
from pandas.core.strings.object_array import ObjectStringArrayMixin

if not pa_version_under10p1:
Expand Down Expand Up @@ -565,10 +564,7 @@ def _convert_int_dtype(self, result):
return result

def _cmp_method(self, other, op):
try:
result = super()._cmp_method(other, op)
except pa.ArrowNotImplementedError:
return invalid_comparison(self, other, op)
result = super()._cmp_method(other, op)
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
Expand Down
31 changes: 26 additions & 5 deletions pandas/tests/series/test_logical_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pandas.compat import HAS_PYARROW

from pandas import (
ArrowDtype,
DataFrame,
Index,
Series,
Expand Down Expand Up @@ -539,18 +540,38 @@ def test_int_dtype_different_index_not_bool(self):
result = ser1 ^ ser2
tm.assert_series_equal(result, expected)

# TODO: this belongs in comparison tests
def test_pyarrow_numpy_string_invalid(self):
# GH#56008
pytest.importorskip("pyarrow")
pa = pytest.importorskip("pyarrow")
ser = Series([False, True])
ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]")
result = ser == ser2
expected = Series(False, index=ser.index)
tm.assert_series_equal(result, expected)
expected_eq = Series(False, index=ser.index)
tm.assert_series_equal(result, expected_eq)

result = ser != ser2
expected = Series(True, index=ser.index)
tm.assert_series_equal(result, expected)
expected_ne = Series(True, index=ser.index)
tm.assert_series_equal(result, expected_ne)

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser2

# GH#59505
ser3 = ser2.astype("string[pyarrow]")
result3_eq = ser3 == ser
tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]"))
result3_ne = ser3 != ser
tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]"))

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser3

ser4 = ser2.astype(ArrowDtype(pa.string()))
result4_eq = ser4 == ser
tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]"))
result4_ne = ser4 != ser
tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]"))

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser4

0 comments on commit d755421

Please sign in to comment.