Skip to content

BUG: Fix series.round() handling of EA #26817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
10 changes: 7 additions & 3 deletions pandas/compat/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
""" support numpy compatiblitiy across versions """

import re
import numpy as np
from distutils.version import LooseVersion
import re

import numpy as np

# numpy versioning
_np_version = np.__version__
Expand All @@ -13,7 +13,10 @@
_np_version_under1p16 = _nlv < LooseVersion('1.16')
_np_version_under1p17 = _nlv < LooseVersion('1.17')
_is_numpy_dev = '.dev' in str(_nlv)

try:
_NEP18_enabled = np.core.overrides.ENABLE_ARRAY_FUNCTION
except Exception:
_NEP18_enabled = False

if _nlv < '1.13.3':
raise ImportError('this version of pandas is incompatible with '
Expand Down Expand Up @@ -62,6 +65,7 @@ def np_array_datetime64_compat(arr, *args, **kwargs):


__all__ = ['np',
'_np_version',
'_np_version_under1p14',
'_np_version_under1p15',
'_np_version_under1p16',
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2098,7 +2098,7 @@ def round(self, decimals=0, *args, **kwargs):
dtype: float64
"""
nv.validate_round(args, kwargs)
result = com.values_from_object(self).round(decimals)
result = np.round(self.array, decimals=decimals)
result = self._constructor(result, index=self.index).__finalize__(self)

return result
Expand Down
9 changes: 4 additions & 5 deletions pandas/tests/arrays/sparse/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest

from pandas._libs.sparse import IntIndex
from pandas.compat.numpy import _np_version_under1p16
import pandas.util._test_decorators as td

import pandas as pd
Expand Down Expand Up @@ -175,8 +174,8 @@ def test_constructor_inferred_fill_value(self, data, fill_value):
@pytest.mark.parametrize('format', ['coo', 'csc', 'csr'])
@pytest.mark.parametrize('size', [
pytest.param(0,
marks=pytest.mark.skipif(_np_version_under1p16,
reason='NumPy-11383')),
marks=td.skip_if_np_lt("1.16",
reason='NumPy-11383')),
10
])
@td.skip_if_no_scipy
Expand Down Expand Up @@ -870,7 +869,7 @@ def test_all(self, data, pos, neg):
([1, 2, 1], 1, 0),
([1.0, 2.0, 1.0], 1.0, 0.0)
])
@td.skip_if_np_lt_115 # prior didn't dispatch
@td.skip_if_np_lt("1.15") # prior didn't dispatch
def test_numpy_all(self, data, pos, neg):
# GH 17570
out = np.all(SparseArray(data))
Expand Down Expand Up @@ -916,7 +915,7 @@ def test_any(self, data, pos, neg):
([0, 2, 0], 2, 0),
([0.0, 2.0, 0.0], 2.0, 0.0)
])
@td.skip_if_np_lt_115 # prior didn't dispatch
@td.skip_if_np_lt("1.15") # prior didn't dispatch
def test_numpy_any(self, data, pos, neg):
# GH 17570
out = np.any(SparseArray(data))
Expand Down
36 changes: 36 additions & 0 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pandas.api.extensions import register_extension_dtype
from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin

_should_cast_results = [np.repeat]


@register_extension_dtype
class DecimalDtype(ExtensionDtype):
Expand Down Expand Up @@ -153,6 +155,40 @@ def _reduce(self, name, skipna=True, **kwargs):
"the {} operation".format(name))
return op(axis=0)

# numpy experimental NEP-18 (opt-in numpy 1.16, enabled in in 1.17)
def __array_function__(self, func, types, args, kwargs):
def coerce_EA(coll):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem like a good idea. For now, just use np.asarray. I think we have a separate issue about converting EAs to preferred ndarray representation, but __array__ is the best we have for now.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've already said I'm not thrilled with this approach, but I'm not sure what you're suggesting.
How can you support pd.concat([EA,EA]) without resorting to something like this?

__array__ in relation to what? I've already discussed how implementing NEP-18 on a class, makes numpy ignore __array__, since point 4 in he first post.

# In order to delegate to numpy, we have to coerce any
# ExtensionArrays to the best numpy-friendly dtype approximation
# Different functions take different arguments, which may be
# nested collections, so we look at everything. Sigh.
for i in range(len(coll)):
if isinstance(coll[i], (tuple, list)):
coll[i] = coerce_EA(list(coll[i]))
else:
if isinstance(coll[i], DecimalArray):
# TODO: how to check for any ndarray-like with
# non-numpy dtype?
coll[i] = np.array(coll[i], dtype=object)

return coll

if func is np.round_:
values = [decimal.Decimal(round(_))
for _ in self._data]
return DecimalArray(values, dtype=self.dtype)

elif True: # just assume we can handle all functions
args = coerce_EA(list(args))
result = func(*args, **kwargs)

if func in _should_cast_results:
result = pd.array(result, dtype=self.dtype)

return result
else:
return NotImplemented


def to_decimal(values, context=None):
return DecimalArray([decimal.Decimal(x) for x in values], context=context)
Expand Down
22 changes: 22 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

import pandas as pd
from pandas.tests.extension import base
import pandas.util.testing as tm
Expand Down Expand Up @@ -380,6 +382,26 @@ def test_divmod_array(reverse, expected_div, expected_mod):
tm.assert_extension_array_equal(mod, expected_mod)


# numpy 1.17 has NEP-18 on by default
# for numpy 1.16 set shell variable with
# "export NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1"
# before running pytest/python.
# verify by checking value of `np.core.overrides.ENABLE_ARRAY_FUNCTION`
@td.skip_if_no_NEP18
def test_series_round():
ser = pd.Series(to_decimal([1.1, 2.4, 3.1])).round()
expected = pd.Series(to_decimal([1, 2, 3]))
tm.assert_extension_array_equal(ser.array, expected.array)
tm.assert_series_equal(ser, expected)


@td.skip_if_no_NEP18
def test_series_round_then_sum():
result = pd.Series(to_decimal([1.1, 2.4, 3.1])).round().sum(skipna=False)
expected = decimal.Decimal("6")
assert result == expected


def test_formatting_values_deprecated():
class DecimalArray2(DecimalArray):
def _formatting_values(self):
Expand Down
16 changes: 8 additions & 8 deletions pandas/tests/frame/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,21 +1565,21 @@ def test_any_all_bool_only(self):
(np.all, {'A': pd.Series([0, 1], dtype=int)}, False),
(np.any, {'A': pd.Series([0, 1], dtype=int)}, True),
pytest.param(np.all, {'A': pd.Series([0, 1], dtype='M8[ns]')}, False,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
pytest.param(np.any, {'A': pd.Series([0, 1], dtype='M8[ns]')}, True,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
pytest.param(np.all, {'A': pd.Series([1, 2], dtype='M8[ns]')}, True,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
pytest.param(np.any, {'A': pd.Series([1, 2], dtype='M8[ns]')}, True,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
pytest.param(np.all, {'A': pd.Series([0, 1], dtype='m8[ns]')}, False,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
pytest.param(np.any, {'A': pd.Series([0, 1], dtype='m8[ns]')}, True,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
pytest.param(np.all, {'A': pd.Series([1, 2], dtype='m8[ns]')}, True,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
pytest.param(np.any, {'A': pd.Series([1, 2], dtype='m8[ns]')}, True,
marks=[td.skip_if_np_lt_115]),
marks=[td.skip_if_np_lt("1.15")]),
(np.all, {'A': pd.Series([0, 1], dtype='category')}, False),
(np.any, {'A': pd.Series([0, 1], dtype='category')}, True),
(np.all, {'A': pd.Series([1, 2], dtype='category')}, True),
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/series/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ def test_value_counts_categorical_not_ordered(self):
dict(keepdims=True),
dict(out=object()),
])
@td.skip_if_np_lt_115
@td.skip_if_np_lt("1.15")
def test_validate_any_all_out_keepdims_raises(self, kwargs, func):
s = pd.Series([1, 2])
param = list(kwargs)[0]
Expand All @@ -1117,7 +1117,7 @@ def test_validate_any_all_out_keepdims_raises(self, kwargs, func):
with pytest.raises(ValueError, match=msg):
func(s, **kwargs)

@td.skip_if_np_lt_115
@td.skip_if_np_lt("1.15")
def test_validate_sum_initial(self):
s = pd.Series([1, 2])
msg = (r"the 'initial' parameter is not "
Expand All @@ -1136,7 +1136,7 @@ def test_validate_median_initial(self):
# method instead of the ufunc.
s.median(overwrite_input=True)

@td.skip_if_np_lt_115
@td.skip_if_np_lt("1.15")
def test_validate_stat_keepdims(self):
s = pd.Series([1, 2])
msg = (r"the 'keepdims' parameter is not "
Expand Down
14 changes: 11 additions & 3 deletions pandas/util/_test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ def test_foo():

For more information, refer to the ``pytest`` documentation on ``skipif``.
"""
from distutils.version import LooseVersion
import locale
from typing import Optional

from _pytest.mark.structures import MarkDecorator
import pytest

from pandas.compat import is_platform_32bit, is_platform_windows
from pandas.compat.numpy import _np_version_under1p15
from pandas.compat.numpy import _NEP18_enabled, _np_version

from pandas.core.computation.expressions import (
_NUMEXPR_INSTALLED, _USE_NUMEXPR)
Expand Down Expand Up @@ -142,8 +143,6 @@ def skip_if_no(

skip_if_no_mpl = pytest.mark.skipif(_skip_if_no_mpl(),
reason="Missing matplotlib dependency")
skip_if_np_lt_115 = pytest.mark.skipif(_np_version_under1p15,
reason="NumPy 1.15 or greater required")
skip_if_mpl = pytest.mark.skipif(not _skip_if_no_mpl(),
reason="matplotlib is present")
skip_if_32bit = pytest.mark.skipif(is_platform_32bit(),
Expand All @@ -166,6 +165,15 @@ def skip_if_no(
"installed->{installed}".format(
enabled=_USE_NUMEXPR,
installed=_NUMEXPR_INSTALLED))
skip_if_no_NEP18 = pytest.mark.skipif(not _NEP18_enabled,
reason="numpy NEP-18 disabled")


def skip_if_np_lt(ver_str, reason=None, *args, **kwds):
if reason is None:
reason = "NumPy %s or greater required" % ver_str
return pytest.mark.skipif(_np_version < LooseVersion(ver_str),
reason=reason, *args, **kwds)


def parametrize_fixture_doc(*args):
Expand Down