diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 161cf3bf3a677..f8ece2a9fe7d4 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1,3 +1,4 @@ +import operator from operator import le, lt import textwrap from typing import TYPE_CHECKING, Optional, Tuple, Union, cast @@ -12,6 +13,7 @@ IntervalMixin, intervals_to_interval_bounds, ) +from pandas._libs.missing import NA from pandas._typing import ArrayLike, Dtype from pandas.compat.numpy import function as nv from pandas.util._decorators import Appender @@ -48,7 +50,7 @@ from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer from pandas.core.indexes.base import ensure_index -from pandas.core.ops import unpack_zerodim_and_defer +from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer if TYPE_CHECKING: from pandas import Index @@ -520,8 +522,7 @@ def __setitem__(self, key, value): self._left[key] = value_left self._right[key] = value_right - @unpack_zerodim_and_defer("__eq__") - def __eq__(self, other): + def _cmp_method(self, other, op): # ensure pandas array for list-like and eliminate non-interval scalars if is_list_like(other): if len(self) != len(other): @@ -529,7 +530,7 @@ def __eq__(self, other): other = array(other) elif not isinstance(other, Interval): # non-interval scalar -> no matches - return np.zeros(len(self), dtype=bool) + return invalid_comparison(self, other, op) # determine the dtype of the elements we want to compare if isinstance(other, Interval): @@ -543,7 +544,8 @@ def __eq__(self, other): # extract intervals if we have interval categories with matching closed if is_interval_dtype(other_dtype): if self.closed != other.categories.closed: - return np.zeros(len(self), dtype=bool) + return invalid_comparison(self, other, op) + other = other.categories.take( other.codes, allow_fill=True, fill_value=other.categories._na_value ) @@ -551,27 +553,70 @@ def __eq__(self, 
other): # interval-like -> need same closed and matching endpoints if is_interval_dtype(other_dtype): if self.closed != other.closed: - return np.zeros(len(self), dtype=bool) - return (self._left == other.left) & (self._right == other.right) + return invalid_comparison(self, other, op) + elif not isinstance(other, Interval): + other = type(self)(other) + + if op is operator.eq: + return (self._left == other.left) & (self._right == other.right) + elif op is operator.ne: + return (self._left != other.left) | (self._right != other.right) + elif op is operator.gt: + return (self._left > other.left) | ( + (self._left == other.left) & (self._right > other.right) + ) + elif op is operator.ge: + return (self == other) | (self > other) + elif op is operator.lt: + return (self._left < other.left) | ( + (self._left == other.left) & (self._right < other.right) + ) + else: + # operator.le + return (self == other) | (self < other) # non-interval/non-object dtype -> no matches if not is_object_dtype(other_dtype): - return np.zeros(len(self), dtype=bool) + return invalid_comparison(self, other, op) # object dtype -> iteratively check for intervals result = np.zeros(len(self), dtype=bool) for i, obj in enumerate(other): - # need object to be an Interval with same closed and endpoints - if ( - isinstance(obj, Interval) - and self.closed == obj.closed - and self._left[i] == obj.left - and self._right[i] == obj.right - ): - result[i] = True - + try: + result[i] = op(self[i], obj) + except TypeError: + if obj is NA: + # comparison with np.nan returns NA + # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092 + result[i] = op is operator.ne + else: + raise return result + @unpack_zerodim_and_defer("__eq__") + def __eq__(self, other): + return self._cmp_method(other, operator.eq) + + @unpack_zerodim_and_defer("__ne__") + def __ne__(self, other): + return self._cmp_method(other, operator.ne) + + @unpack_zerodim_and_defer("__gt__") + def __gt__(self, other): + return 
self._cmp_method(other, operator.gt) + + @unpack_zerodim_and_defer("__ge__") + def __ge__(self, other): + return self._cmp_method(other, operator.ge) + + @unpack_zerodim_and_defer("__lt__") + def __lt__(self, other): + return self._cmp_method(other, operator.lt) + + @unpack_zerodim_and_defer("__le__") + def __le__(self, other): + return self._cmp_method(other, operator.le) + def fillna(self, value=None, method=None, limit=None): """ Fill NA/NaN values using the specified method. diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 1bd71f00b534d..2061e652a4c01 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1074,19 +1074,6 @@ def _is_all_dates(self) -> bool: # TODO: arithmetic operations - # GH#30817 until IntervalArray implements inequalities, get them from Index - def __lt__(self, other): - return Index.__lt__(self, other) - - def __le__(self, other): - return Index.__le__(self, other) - - def __gt__(self, other): - return Index.__gt__(self, other) - - def __ge__(self, other): - return Index.__ge__(self, other) - def _is_valid_endpoint(endpoint) -> bool: """ diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index e973b1247941f..29a59cdefbd83 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -447,7 +447,7 @@ def test_repeat(self, data, repeats, as_series, use_numpy): @pytest.mark.parametrize( "repeats, kwargs, error, msg", [ - (2, dict(axis=1), ValueError, "'axis"), + (2, dict(axis=1), ValueError, "axis"), (-1, dict(), ValueError, "negative"), ([1, 2], dict(), ValueError, "shape"), (2, dict(foo="bar"), TypeError, "'foo'"), diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 67e031b53e44e..157446b1fff5d 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -579,9 +579,11 @@ def 
test_comparison(self): actual = self.index == self.index.left tm.assert_numpy_array_equal(actual, np.array([False, False])) - msg = ( - "not supported between instances of 'int' and " - "'pandas._libs.interval.Interval'" + msg = "|".join( + [ + "not supported between instances of 'int' and '.*.Interval'", + r"Invalid comparison between dtype=interval\[int64\] and ", + ] ) with pytest.raises(TypeError, match=msg): self.index > 0