Skip to content

Commit 6fe6832

Browse files
DGrady authored and jreback committed
BUG: Fix behavior of argmax and argmin with inf (#16449)
Closes #13595
1 parent 47b3973 commit 6fe6832

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

doc/source/whatsnew/v0.21.0.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ Other API Changes
266266
- Removed the ``@slow`` decorator from ``pandas.util.testing``, which caused issues for some downstream packages' test suites. Use ``@pytest.mark.slow`` instead, which achieves the same thing (:issue:`16850`)
267267
- Moved definition of ``MergeError`` to the ``pandas.errors`` module.
268268
- The signature of :func:`Series.set_axis` and :func:`DataFrame.set_axis` has been changed from ``set_axis(axis, labels)`` to ``set_axis(labels, axis=0)``, for consistency with the rest of the API. The old signature is deprecated and will show a ``FutureWarning`` (:issue:`14636`)
269-
269+
- :func:`Series.argmin` and :func:`Series.argmax` will now raise a ``TypeError`` when used with ``object`` dtypes, instead of a ``ValueError`` (:issue:`13595`)
270270

271271
.. _whatsnew_0210.deprecations:
272272

@@ -374,6 +374,7 @@ Reshaping
374374
- Fixes regression from 0.20, :func:`Series.aggregate` and :func:`DataFrame.aggregate` allow dictionaries as return values again (:issue:`16741`)
375375
- Fixes dtype of result with integer dtype input, from :func:`pivot_table` when called with ``margins=True`` (:issue:`17013`)
376376
- Bug in :func:`crosstab` where passing two ``Series`` with the same name raised a ``KeyError`` (:issue:`13279`)
377+
- :func:`Series.argmin`, :func:`Series.argmax`, and their counterparts on ``DataFrame`` and groupby objects work correctly with floating point data that contains infinite values (:issue:`13595`).
377378

378379
Numeric
379380
^^^^^^^

pandas/core/nanops.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -486,23 +486,23 @@ def reduction(values, axis=None, skipna=True):
486486
nanmax = _nanminmax('max', fill_value_typ='-inf')
487487

488488

489+
@disallow('O')
489490
def nanargmax(values, axis=None, skipna=True):
490491
"""
491492
Returns -1 in the NA case
492493
"""
493-
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='-inf',
494-
isfinite=True)
494+
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='-inf')
495495
result = values.argmax(axis)
496496
result = _maybe_arg_null_out(result, axis, mask, skipna)
497497
return result
498498

499499

500+
@disallow('O')
500501
def nanargmin(values, axis=None, skipna=True):
501502
"""
502503
Returns -1 in the NA case
503504
"""
504-
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='+inf',
505-
isfinite=True)
505+
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='+inf')
506506
result = values.argmin(axis)
507507
result = _maybe_arg_null_out(result, axis, mask, skipna)
508508
return result

pandas/tests/groupby/test_groupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2339,7 +2339,7 @@ def test_non_cython_api(self):
23392339
assert_frame_equal(result, expected)
23402340

23412341
# idxmax
2342-
expected = DataFrame([[0], [nan]], columns=['B'], index=[1, 3])
2342+
expected = DataFrame([[0.0], [nan]], columns=['B'], index=[1, 3])
23432343
expected.index.name = 'A'
23442344
result = g.idxmax()
23452345
assert_frame_equal(result, expected)

pandas/tests/series/test_operators.py

+47
Original file line numberDiff line numberDiff line change
@@ -1857,3 +1857,50 @@ def test_op_duplicate_index(self):
18571857
result = s1 + s2
18581858
expected = pd.Series([11, 12, np.nan], index=[1, 1, 2])
18591859
assert_series_equal(result, expected)
1860+
1861+
@pytest.mark.parametrize(
1862+
"test_input,error_type",
1863+
[
1864+
(pd.Series([]), ValueError),
1865+
1866+
# For strings, or any Series with dtype 'O'
1867+
(pd.Series(['foo', 'bar', 'baz']), TypeError),
1868+
(pd.Series([(1,), (2,)]), TypeError),
1869+
1870+
# For mixed data types
1871+
(
1872+
pd.Series(['foo', 'foo', 'bar', 'bar', None, np.nan, 'baz']),
1873+
TypeError
1874+
),
1875+
]
1876+
)
1877+
def test_assert_argminmax_raises(self, test_input, error_type):
1878+
"""
1879+
Cases where ``Series.argmax`` and related should raise an exception
1880+
"""
1881+
with pytest.raises(error_type):
1882+
test_input.argmin()
1883+
with pytest.raises(error_type):
1884+
test_input.argmin(skipna=False)
1885+
with pytest.raises(error_type):
1886+
test_input.argmax()
1887+
with pytest.raises(error_type):
1888+
test_input.argmax(skipna=False)
1889+
1890+
def test_argminmax_with_inf(self):
1891+
# For numeric data with NA and Inf (GH #13595)
1892+
s = pd.Series([0, -np.inf, np.inf, np.nan])
1893+
1894+
assert s.argmin() == 1
1895+
assert np.isnan(s.argmin(skipna=False))
1896+
1897+
assert s.argmax() == 2
1898+
assert np.isnan(s.argmax(skipna=False))
1899+
1900+
# Using old-style behavior that treats floating point nan, -inf, and
1901+
# +inf as missing
1902+
with pd.option_context('mode.use_inf_as_na', True):
1903+
assert s.argmin() == 0
1904+
assert np.isnan(s.argmin(skipna=False))
1905+
assert s.argmax() == 0
1906+
assert np.isnan(s.argmax(skipna=False))

0 commit comments

Comments
 (0)