From a3b7c4569e1a09c97a6f91b1a799fefad64502e7 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 14 Oct 2020 14:47:13 -0700 Subject: [PATCH 1/7] ENH: IntervalArray comparisons --- pandas/core/arrays/_mixins.py | 2 - pandas/core/arrays/interval.py | 92 +++++++++++++++---- pandas/tests/arithmetic/test_interval.py | 5 - pandas/tests/extension/base/methods.py | 2 +- .../tests/indexes/interval/test_interval.py | 8 +- 5 files changed, 79 insertions(+), 30 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 95a003efbe1d0..b691f425b8436 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -3,7 +3,6 @@ import numpy as np from pandas._libs import lib -from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly, doc from pandas.util._validators import validate_fillna_kwargs @@ -139,7 +138,6 @@ def repeat(self: _T, repeats, axis=None) -> _T: -------- numpy.ndarray.repeat """ - nv.validate_repeat(tuple(), dict(axis=axis)) new_data = self._ndarray.repeat(repeats, axis=axis) return self._from_backing_data(new_data) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 09488b9576212..db1f3377b8cb9 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1,3 +1,4 @@ +import operator from operator import le, lt import textwrap from typing import TYPE_CHECKING, Optional, Tuple, Union, cast @@ -48,7 +49,7 @@ from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer from pandas.core.indexes.base import ensure_index -from pandas.core.ops import unpack_zerodim_and_defer +from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer if TYPE_CHECKING: from pandas import Index @@ -520,8 +521,7 @@ def __setitem__(self, key, value): self._left[key] = value_left self._right[key] = value_right - @unpack_zerodim_and_defer("__eq__") - def __eq__(self, other): + def _cmp_method(self, other, op): # ensure pandas array for list-like and eliminate non-interval scalars if is_list_like(other): if len(self) != len(other): @@ -529,7 +529,7 @@ def __eq__(self, other): other = array(other) elif not isinstance(other, Interval): # non-interval scalar -> no matches - return np.zeros(len(self), dtype=bool) + return invalid_comparison(self, other, op) # determine the dtype of the elements we want to compare if isinstance(other, Interval): @@ -543,33 +543,87 @@ def __eq__(self, other): # extract intervals if we have interval categories with matching closed if is_interval_dtype(other_dtype): if self.closed != other.categories.closed: - return np.zeros(len(self), dtype=bool) + return invalid_comparison(self, other, op) other = other.categories.take(other.codes) # interval-like -> need same closed and matching endpoints if is_interval_dtype(other_dtype): if self.closed != other.closed: - return np.zeros(len(self), dtype=bool) - return (self._left == other.left) & (self._right == other.right) + return invalid_comparison(self, other, op) + if isinstance(other, Interval): + other = type(self)._from_sequence([other]) + if self._combined.dtype.kind in ["m", "M"]: + # Need to repeat bc we do not broadcast length-1 + # TODO: would be helpful to have a tile method to do + # this without copies + other = other.repeat(len(self)) + else: + other = type(self)(other) + + if op is operator.eq: + return (self._combined[:, 0] == other._left) & ( + self._combined[:, 1] == other._right + ) + elif op is operator.ne: + return (self._combined[:, 0] != other._left) | ( + self._combined[:, 1] != other._right + ) + elif op is operator.gt: + return (self._combined[:, 0] > other._combined[:, 0]) | ( + (self._combined[:, 0] == other._left) + & (self._combined[:, 1] > other._right) + ) + elif op is operator.ge: + return (self == other) | (self > other) + elif op is operator.lt: + return (self._combined[:, 0] < other._combined[:, 0]) | ( + (self._combined[:, 0] == other._left) + & (self._combined[:, 1] < other._right) + ) + else: + # operator.lt + return (self == other) | (self < other) # non-interval/non-object dtype -> no matches if not is_object_dtype(other_dtype): - return np.zeros(len(self), dtype=bool) + return invalid_comparison(self, other, op) # object dtype -> iteratively check for intervals - result = np.zeros(len(self), dtype=bool) - for i, obj in enumerate(other): - # need object to be an Interval with same closed and endpoints - if ( - isinstance(obj, Interval) - and self.closed == obj.closed - and self._left[i] == obj.left - and self._right[i] == obj.right - ): - result[i] = True - + try: + result = np.zeros(len(self), dtype=bool) + for i, obj in enumerate(other): + result[i] = op(self[i], obj) + except TypeError: + # pd.NA + result = np.zeros(len(self), dtype=object) + for i, obj in enumerate(other): + result[i] = op(self[i], obj) return result + @unpack_zerodim_and_defer("__eq__") + def __eq__(self, other): + return self._cmp_method(other, operator.eq) + + @unpack_zerodim_and_defer("__ne__") + def __ne__(self, other): + return self._cmp_method(other, operator.ne) + + @unpack_zerodim_and_defer("__gt__") + def __gt__(self, other): + return self._cmp_method(other, operator.gt) + + @unpack_zerodim_and_defer("__ge__") + def __ge__(self, other): + return self._cmp_method(other, operator.ge) + + @unpack_zerodim_and_defer("__lt__") + def __lt__(self, other): + return self._cmp_method(other, operator.lt) + + @unpack_zerodim_and_defer("__le__") + def __le__(self, other): + return self._cmp_method(other, operator.le) + def fillna(self, value=None, method=None, limit=None): """ Fill NA/NaN values using the specified method. diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py index 03cc4fe2bdcb5..8ab07a358ee51 100644 --- a/pandas/tests/arithmetic/test_interval.py +++ b/pandas/tests/arithmetic/test_interval.py @@ -216,11 +216,6 @@ def test_compare_list_like_nan(self, op, array, nulls_fixture, request): result = op(array, other) expected = self.elementwise_comparison(op, array, other) - if nulls_fixture is pd.NA and array.dtype.subtype != "i8": - reason = "broken for non-integer IntervalArray; see GH 31882" - mark = pytest.mark.xfail(reason=reason) - request.node.add_marker(mark) - tm.assert_numpy_array_equal(result, expected) @pytest.mark.parametrize( diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 23e20a2c0903a..94533dcc08c48 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -443,7 +443,7 @@ def test_repeat(self, data, repeats, as_series, use_numpy): @pytest.mark.parametrize( "repeats, kwargs, error, msg", [ - (2, dict(axis=1), ValueError, "'axis"), + (2, dict(axis=1), ValueError, "axis"), (-1, dict(), ValueError, "negative"), ([1, 2], dict(), ValueError, "shape"), (2, dict(foo="bar"), TypeError, "'foo'"), diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 17a1c69858c11..e99b3dcc36fbb 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -579,9 +579,11 @@ def test_comparison(self): actual = self.index == self.index.left tm.assert_numpy_array_equal(actual, np.array([False, False])) - msg = ( - "not supported between instances of 'int' and " - "'pandas._libs.interval.Interval'" + msg = "|".join( + [ + "not supported between instances of 'int' and '.*.Interval'", + r"Invalid comparison between dtype=interval\[int64\] and ", + ] ) with pytest.raises(TypeError, match=msg): self.index > 0 From cf20846c33d435b4a797d6d3e29679a3f447d2da Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 14 Oct 2020 16:01:46 -0700 Subject: [PATCH 2/7] CLN: get IntervalIndex comparisons from IntervalArray --- pandas/core/indexes/interval.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index cb25ef1241ce0..969638367ae92 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1105,19 +1105,6 @@ def _is_all_dates(self) -> bool: # TODO: arithmetic operations - # GH#30817 until IntervalArray implements inequalities, get them from Index - def __lt__(self, other): - return Index.__lt__(self, other) - - def __le__(self, other): - return Index.__le__(self, other) - - def __gt__(self, other): - return Index.__gt__(self, other) - - def __ge__(self, other): - return Index.__ge__(self, other) - def _is_valid_endpoint(endpoint) -> bool: """ From fcfe47d29eedb726344bc6e22ea307c49cf2c3cb Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 20 Oct 2020 10:34:17 -0700 Subject: [PATCH 3/7] update per requests --- pandas/core/arrays/interval.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index db1f3377b8cb9..77f60dc879ce7 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -552,7 +552,7 @@ def _cmp_method(self, other, op): return invalid_comparison(self, other, op) if isinstance(other, Interval): other = type(self)._from_sequence([other]) - if self._combined.dtype.kind in ["m", "M"]: + if self._left.dtype.kind in ["m", "M"]: # Need to repeat bc we do not broadcast length-1 # TODO: would be helpful to have a tile method to do # this without copies @@ -561,24 +561,18 @@ def _cmp_method(self, other, op): other = type(self)(other) if op is operator.eq: - return (self._combined[:, 0] == other._left) & ( - self._combined[:, 1] == other._right - ) + return (self._left == other._left) & (self._right == other._right) elif op is operator.ne: - return (self._combined[:, 0] != other._left) | ( - self._combined[:, 1] != other._right - ) + return (self._left != other._left) | (self._right != other._right) elif op is operator.gt: - return (self._combined[:, 0] > other._combined[:, 0]) | ( - (self._combined[:, 0] == other._left) - & (self._combined[:, 1] > other._right) + return (self._left > other._left) | ( + (self._left == other._left) & (self._right > other._right) ) elif op is operator.ge: return (self == other) | (self > other) elif op is operator.lt: - return (self._combined[:, 0] < other._combined[:, 0]) | ( - (self._combined[:, 0] == other._left) - & (self._combined[:, 1] < other._right) + return (self._left < other._left) | ( + (self._left == other._left) & (self._right < other._right) ) else: # operator.lt @@ -589,8 +583,8 @@ def _cmp_method(self, other, op): return invalid_comparison(self, other, op) # object dtype -> iteratively check for intervals + result = np.zeros(len(self), dtype=bool) try: - result = np.zeros(len(self), dtype=bool) for i, obj in enumerate(other): result[i] = op(self[i], obj) except TypeError: From fa6cecdebcc049f0ed278eff141e81b9740c3eb2 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 23 Oct 2020 18:01:17 -0700 Subject: [PATCH 4/7] Avoid having to tile --- pandas/core/arrays/interval.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 77f60dc879ce7..d88d5882ffc33 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -550,29 +550,22 @@ def _cmp_method(self, other, op): if is_interval_dtype(other_dtype): if self.closed != other.closed: return invalid_comparison(self, other, op) - if isinstance(other, Interval): - other = type(self)._from_sequence([other]) - if self._left.dtype.kind in ["m", "M"]: - # Need to repeat bc we do not broadcast length-1 - # TODO: would be helpful to have a tile method to do - # this without copies - other = other.repeat(len(self)) - else: + elif not isinstance(other, Interval): other = type(self)(other) if op is operator.eq: - return (self._left == other._left) & (self._right == other._right) + return (self._left == other.left) & (self._right == other.right) elif op is operator.ne: - return (self._left != other._left) | (self._right != other._right) + return (self._left != other.left) | (self._right != other.right) elif op is operator.gt: - return (self._left > other._left) | ( - (self._left == other._left) & (self._right > other._right) + return (self._left > other.left) | ( + (self._left == other.left) & (self._right > other.right) ) elif op is operator.ge: return (self == other) | (self > other) elif op is operator.lt: - return (self._left < other._left) | ( - (self._left == other._left) & (self._right < other._right) + return (self._left < other.left) | ( + (self._left == other.left) & (self._right < other.right) ) else: # operator.lt From ff640ea8f8ca52059ed85aac1e71ad0955fd97c1 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 24 Oct 2020 15:12:10 -0700 Subject: [PATCH 5/7] handle NA per suggestion --- pandas/core/arrays/interval.py | 16 +++++++++------- pandas/tests/arithmetic/test_interval.py | 5 +++++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index d88d5882ffc33..dae3b70ed6b74 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -13,6 +13,7 @@ IntervalMixin, intervals_to_interval_bounds, ) +from pandas._libs.missing import NA from pandas._typing import ArrayLike, Dtype from pandas.compat.numpy import function as nv from pandas.util._decorators import Appender @@ -577,14 +578,15 @@ def _cmp_method(self, other, op): # object dtype -> iteratively check for intervals result = np.zeros(len(self), dtype=bool) - try: - for i, obj in enumerate(other): - result[i] = op(self[i], obj) - except TypeError: - # pd.NA - result = np.zeros(len(self), dtype=object) - for i, obj in enumerate(other): + for i, obj in enumerate(other): + try: result[i] = op(self[i], obj) + except TypeError: + if obj is NA: + # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092 + result[i] = op is operator.ne + else: + raise return result @unpack_zerodim_and_defer("__eq__") diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py index b7011ed223166..30a23d8563ef8 100644 --- a/pandas/tests/arithmetic/test_interval.py +++ b/pandas/tests/arithmetic/test_interval.py @@ -216,6 +216,11 @@ def test_compare_list_like_nan(self, op, array, nulls_fixture, request): result = op(array, other) expected = self.elementwise_comparison(op, array, other) + if nulls_fixture is pd.NA and array.dtype.subtype != "i8": + reason = "broken for non-integer IntervalArray; see GH 31882" + mark = pytest.mark.xfail(reason=reason) + request.node.add_marker(mark) + tm.assert_numpy_array_equal(result, expected) @pytest.mark.parametrize( From 247ce9083eab161be84cebef752c318ac0c567cf Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 27 Oct 2020 11:08:38 -0700 Subject: [PATCH 6/7] comment --- pandas/core/arrays/_mixins.py | 2 ++ pandas/core/arrays/interval.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 4f49cd6b3faf2..948ffdc1f7c01 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -3,6 +3,7 @@ import numpy as np from pandas._libs import lib +from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly, doc from pandas.util._validators import validate_fillna_kwargs @@ -138,6 +139,7 @@ def repeat(self: _T, repeats, axis=None) -> _T: -------- numpy.ndarray.repeat """ + nv.validate_repeat(tuple(), dict(axis=axis)) new_data = self._ndarray.repeat(repeats, axis=axis) return self._from_backing_data(new_data) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index dae3b70ed6b74..b943c811c54f9 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -583,6 +583,8 @@ def _cmp_method(self, other, op): result[i] = op(self[i], obj) except TypeError: if obj is NA: + # comparison returns NA, which we (for now?) treat like + # other NAs # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092 result[i] = op is operator.ne else: From badb99d0a6873ce717cd621b785ac90edb159338 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 2 Nov 2020 12:22:34 -0800 Subject: [PATCH 7/7] update comment --- pandas/core/arrays/interval.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 3488d35e3518a..f8ece2a9fe7d4 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -586,8 +586,7 @@ def _cmp_method(self, other, op): result[i] = op(self[i], obj) except TypeError: if obj is NA: - # comparison returns NA, which we (for now?) treat like - # other NAs + # comparison with np.nan returns NA # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092 result[i] = op is operator.ne else: