Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG (string): ArrowEA comparisons with mismatched types #59505

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ ExtensionArray
^^^^^^^^^^^^^^
- Bug in :meth:`.arrays.ArrowExtensionArray.__setitem__` which caused wrong behavior when using an integer array with repeated values as a key (:issue:`58530`)
- Bug in :meth:`api.types.is_datetime64_any_dtype` where a custom :class:`ExtensionDtype` would return ``False`` for array-likes (:issue:`57055`)
- Bug in comparisons between an object with :class:`ArrowDtype` and an object with an incompatible dtype (e.g. string vs bool) incorrectly raising instead of returning all-``False`` (for ``==``) or all-``True`` (for ``!=``) (:issue:`59505`)
- Bug in various :class:`DataFrame` reductions for pyarrow temporal dtypes returning incorrect dtype when result was null (:issue:`59234`)

Styler
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,13 @@ def _cmp_method(self, other, op) -> ArrowExtensionArray:
if isinstance(
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
result = pc_func(self._pa_array, self._box_pa(other))
try:
result = pc_func(self._pa_array, self._box_pa(other))
except pa.ArrowNotImplementedError:
# TODO: could this be wrong if other is object dtype?
# in which case we need to operate pointwise?
result = ops.invalid_comparison(self, other, op)
result = pa.array(result, type=pa.bool_())
elif is_scalar(other):
try:
result = pc_func(self._pa_array, self._box_pa(other))
Expand Down
6 changes: 1 addition & 5 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
BaseStringArray,
StringDtype,
)
from pandas.core.ops import invalid_comparison
from pandas.core.strings.object_array import ObjectStringArrayMixin

if not pa_version_under10p1:
Expand Down Expand Up @@ -563,10 +562,7 @@ def _convert_int_dtype(self, result):
return result

def _cmp_method(self, other, op):
try:
result = super()._cmp_method(other, op)
except pa.ArrowNotImplementedError:
return invalid_comparison(self, other, op)
result = super()._cmp_method(other, op)
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
Expand Down
31 changes: 26 additions & 5 deletions pandas/tests/series/test_logical_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pandas.compat import HAS_PYARROW

from pandas import (
ArrowDtype,
DataFrame,
Index,
Series,
Expand Down Expand Up @@ -523,18 +524,38 @@ def test_int_dtype_different_index_not_bool(self):
result = ser1 ^ ser2
tm.assert_series_equal(result, expected)

# TODO: this belongs in comparison tests
def test_pyarrow_numpy_string_invalid(self):
# GH#56008
pytest.importorskip("pyarrow")
pa = pytest.importorskip("pyarrow")
ser = Series([False, True])
ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]")
result = ser == ser2
expected = Series(False, index=ser.index)
tm.assert_series_equal(result, expected)
expected_eq = Series(False, index=ser.index)
tm.assert_series_equal(result, expected_eq)

result = ser != ser2
expected = Series(True, index=ser.index)
tm.assert_series_equal(result, expected)
expected_ne = Series(True, index=ser.index)
tm.assert_series_equal(result, expected_ne)

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser2

# GH#59505
ser3 = ser2.astype("string[pyarrow]")
result3_eq = ser3 == ser
tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]"))
result3_ne = ser3 != ser
tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]"))

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser3

ser4 = ser2.astype(ArrowDtype(pa.string()))
result4_eq = ser4 == ser
tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]"))
result4_ne = ser4 != ser
tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]"))

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser4