Skip to content

Commit

Permalink
ENH: .equals for Extension Arrays (#30652)
Browse files Browse the repository at this point in the history
Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
  • Loading branch information
dwhu and jorisvandenbossche authored May 9, 2020
1 parent 6388370 commit f21bc99
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 9 deletions.
1 change: 1 addition & 0 deletions doc/source/reference/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ objects.
api.extensions.ExtensionArray.copy
api.extensions.ExtensionArray.view
api.extensions.ExtensionArray.dropna
api.extensions.ExtensionArray.equals
api.extensions.ExtensionArray.factorize
api.extensions.ExtensionArray.fillna
api.extensions.ExtensionArray.isna
Expand Down
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ Other enhancements
such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`)
- :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
- :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
- The ``ExtensionArray`` class has now an :meth:`~pandas.arrays.ExtensionArray.equals`
method, similarly to :meth:`Series.equals` (:issue:`27081`).
-

.. ---------------------------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion pandas/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,7 +1490,9 @@ def box_expected(expected, box_cls, transpose=True):
-------
subclass of box_cls
"""
if box_cls is pd.Index:
if box_cls is pd.array:
expected = pd.array(expected)
elif box_cls is pd.Index:
expected = pd.Index(expected)
elif box_cls is pd.Series:
expected = pd.Series(expected)
Expand Down
58 changes: 55 additions & 3 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ExtensionArray:
dropna
factorize
fillna
equals
isna
ravel
repeat
Expand All @@ -84,6 +85,7 @@ class ExtensionArray:
* _from_factorized
* __getitem__
* __len__
* __eq__
* dtype
* nbytes
* isna
Expand Down Expand Up @@ -333,6 +335,24 @@ def __iter__(self):
for i in range(len(self)):
yield self[i]

def __eq__(self, other: Any) -> ArrayLike:
"""
Return for `self == other` (element-wise equality).
"""
# Implementer note: this should return a boolean numpy ndarray or
# a boolean ExtensionArray.
# When `other` is one of Series, Index, or DataFrame, this method should
# return NotImplemented (to ensure that those objects are responsible for
# first unpacking the arrays, and then dispatch the operation to the
# underlying arrays)
raise AbstractMethodError(self)

def __ne__(self, other: Any) -> ArrayLike:
"""
Return for `self != other` (element-wise in-equality).
"""
return ~(self == other)

def to_numpy(
self, dtype=None, copy: bool = False, na_value=lib.no_default
) -> np.ndarray:
Expand Down Expand Up @@ -682,6 +702,38 @@ def searchsorted(self, value, side="left", sorter=None):
arr = self.astype(object)
return arr.searchsorted(value, side=side, sorter=sorter)

def equals(self, other: "ExtensionArray") -> bool:
"""
Return if another array is equivalent to this array.
Equivalent means that both arrays have the same shape and dtype, and
all values compare equal. Missing values in the same location are
considered equal (in contrast with normal equality).
Parameters
----------
other : ExtensionArray
Array to compare to this Array.
Returns
-------
boolean
Whether the arrays are equivalent.
"""
if not type(self) == type(other):
return False
elif not self.dtype == other.dtype:
return False
elif not len(self) == len(other):
return False
else:
equal_values = self == other
if isinstance(equal_values, ExtensionArray):
# boolean array with NA -> fill with False
equal_values = equal_values.fillna(False)
equal_na = self.isna() & other.isna()
return (equal_values | equal_na).all().item()

def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
"""
Return an array and missing value suitable for factorization.
Expand Down Expand Up @@ -1134,7 +1186,7 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin):
"""

@classmethod
def _create_method(cls, op, coerce_to_dtype=True):
def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None):
"""
A class method that returns a method that will correspond to an
operator for an ExtensionArray subclass, by dispatching to the
Expand Down Expand Up @@ -1202,7 +1254,7 @@ def _maybe_convert(arr):
# exception raised in _from_sequence; ensure we have ndarray
res = np.asarray(arr)
else:
res = np.asarray(arr)
res = np.asarray(arr, dtype=result_dtype)
return res

if op.__name__ in {"divmod", "rdivmod"}:
Expand All @@ -1220,4 +1272,4 @@ def _create_arithmetic_method(cls, op):

@classmethod
def _create_comparison_method(cls, op):
return cls._create_method(op, coerce_to_dtype=False)
return cls._create_method(op, coerce_to_dtype=False, result_dtype=bool)
3 changes: 0 additions & 3 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,6 @@ def __eq__(self, other):

return result

def __ne__(self, other):
return ~self.__eq__(other)

def fillna(self, value=None, method=None, limit=None):
"""
Fill NA/NaN values using the specified method.
Expand Down
3 changes: 3 additions & 0 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,9 @@ def where(

return [self.make_block_same_class(result, placement=self.mgr_locs)]

def equals(self, other) -> bool:
return self.values.equals(other.values)

def _unstack(self, unstacker, fill_value, new_placement):
# ExtensionArray-safe unstack.
# We override ObjectBlock._unstack, which unstacks directly on the
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/arrays/integer/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,13 @@ def test_compare_to_int(self, any_nullable_int_dtype, all_compare_operators):
expected[s2.isna()] = pd.NA

self.assert_series_equal(result, expected)


def test_equals():
# GH-30652
# equals is generally tested in /tests/extension/base/methods, but this
# specifically tests that two arrays of the same class but different dtype
# do not evaluate equal
a1 = pd.array([1, 2, None], dtype="Int64")
a2 = pd.array([1, 2, None], dtype="Int32")
assert a1.equals(a2) is False
29 changes: 29 additions & 0 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,32 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy):
np.repeat(data, repeats, **kwargs)
else:
data.repeat(repeats, **kwargs)

@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
def test_equals(self, data, na_value, as_series, box):
data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype)
data_na = type(data)._from_sequence([na_value] * len(data), dtype=data.dtype)

data = tm.box_expected(data, box, transpose=False)
data2 = tm.box_expected(data2, box, transpose=False)
data_na = tm.box_expected(data_na, box, transpose=False)

# we are asserting with `is True/False` explicitly, to test that the
# result is an actual Python bool, and not something "truthy"

assert data.equals(data) is True
assert data.equals(data.copy()) is True

# unequal other data
assert data.equals(data2) is False
assert data.equals(data_na) is False

# different length
assert data[:2].equals(data[:3]) is False

# emtpy are equal
assert data[:0].equals(data[:0]) is True

# other types
assert data.equals(None) is False
assert data[[0]].equals(data[0]) is False
8 changes: 6 additions & 2 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,8 @@ class BaseComparisonOpsTests(BaseOpsUtil):
def _compare_other(self, s, data, op_name, other):
op = self.get_op_from_name(op_name)
if op_name == "__eq__":
assert getattr(data, op_name)(other) is NotImplemented
assert not op(s, other).all()
elif op_name == "__ne__":
assert getattr(data, op_name)(other) is NotImplemented
assert op(s, other).all()

else:
Expand Down Expand Up @@ -176,6 +174,12 @@ def test_direct_arith_with_series_returns_not_implemented(self, data):
else:
raise pytest.skip(f"{type(data).__name__} does not implement __eq__")

if hasattr(data, "__ne__"):
result = data.__ne__(other)
assert result is NotImplemented
else:
raise pytest.skip(f"{type(data).__name__} does not implement __ne__")


class BaseUnaryOpsTests(BaseOpsUtil):
def test_invert(self, data):
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/extension/json/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def __setitem__(self, key, value):
def __len__(self) -> int:
return len(self.data)

def __eq__(self, other):
return NotImplemented

def __ne__(self, other):
return NotImplemented

def __array__(self, dtype=None):
if dtype is None:
dtype = object
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ def test_where_series(self, data, na_value):
def test_searchsorted(self, data_for_sorting):
super().test_searchsorted(data_for_sorting)

@pytest.mark.skip(reason="Can't compare dicts.")
def test_equals(self, data, na_value, as_series):
pass


class TestCasting(BaseJSON, base.BaseCastingTests):
@pytest.mark.skip(reason="failing on np.array(self, dtype=str)")
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ def test_repeat(self, data, repeats, as_series, use_numpy):
def test_diff(self, data, periods):
return super().test_diff(data, periods)

@skip_nested
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
def test_equals(self, data, na_value, as_series, box):
# Fails creating with _from_sequence
super().test_equals(data, na_value, as_series, box)


@skip_nested
class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests):
Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ def test_shift_0_periods(self, data):
data._sparse_values[0] = data._sparse_values[1]
assert result._sparse_values[0] != result._sparse_values[1]

@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
def test_equals(self, data, na_value, as_series, box):
self._check_unsupported(data)
super().test_equals(data, na_value, as_series, box)


class TestCasting(BaseSparseTests, base.BaseCastingTests):
def test_astype_object_series(self, all_data):
Expand Down

0 comments on commit f21bc99

Please sign in to comment.