Skip to content

Commit 3b26841

Browse files
authored
BUG: idxmin & idxmax axis = 1 str reducer for transform (#50329)
1 parent 3a0db10 commit 3b26841

File tree

3 files changed

+143
-15
lines changed

3 files changed

+143
-15
lines changed

doc/source/whatsnew/v2.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ Groupby/resample/rolling
955955
- Bug in :meth:`DataFrame.groupby` would not include a :class:`.Grouper` specified by ``key`` in the result when ``as_index=False`` (:issue:`50413`)
956956
- Bug in :meth:`.DataFrameGroupBy.value_counts` would raise when used with a :class:`.TimeGrouper` (:issue:`50486`)
957957
- Bug in :meth:`Resampler.size` caused a wide :class:`DataFrame` to be returned instead of a :class:`Series` with :class:`MultiIndex` (:issue:`46826`)
958+
- Bug in :meth:`.DataFrameGroupBy.transform` and :meth:`.SeriesGroupBy.transform` would raise incorrectly when grouper had ``axis=1`` for ``"idxmin"`` and ``"idxmax"`` arguments (:issue:`45986`)
958959
-
959960

960961
Reshaping

pandas/core/groupby/generic.py

Lines changed: 142 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@
9595
)
9696
from pandas.core.indexes.category import CategoricalIndex
9797
from pandas.core.series import Series
98-
from pandas.core.shared_docs import _shared_docs
9998
from pandas.core.util.numba_ import maybe_use_numba
10099

101100
from pandas.plotting import boxplot_frame_groupby
@@ -1848,17 +1847,82 @@ def nunique(self, dropna: bool = True) -> DataFrame:
18481847

18491848
return results
18501849

1851-
@doc(
1852-
_shared_docs["idxmax"],
1853-
numeric_only_default="False",
1854-
)
18551850
def idxmax(
18561851
self,
1857-
axis: Axis = 0,
1852+
axis: Axis | None = None,
18581853
skipna: bool = True,
18591854
numeric_only: bool = False,
18601855
) -> DataFrame:
1861-
axis = DataFrame._get_axis_number(axis)
1856+
"""
1857+
Return index of first occurrence of maximum over requested axis.
1858+
1859+
NA/null values are excluded.
1860+
1861+
Parameters
1862+
----------
1863+
axis : {0 or 'index', 1 or 'columns'}, default None
1864+
The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise.
1865+
If axis is not provided, grouper's axis is used.
1866+
1867+
.. versionchanged:: 2.0.0
1868+
1869+
skipna : bool, default True
1870+
Exclude NA/null values. If an entire row/column is NA, the result
1871+
will be NA.
1872+
numeric_only : bool, default False
1873+
Include only `float`, `int` or `boolean` data.
1874+
1875+
.. versionadded:: 1.5.0
1876+
1877+
Returns
1878+
-------
1879+
DataFrame
1880+
Indexes of maxima along the specified axis.
1881+
1882+
Raises
1883+
------
1884+
ValueError
1885+
* If the row/column is empty
1886+
1887+
See Also
1888+
--------
1889+
Series.idxmax : Return index of the maximum element.
1890+
1891+
Notes
1892+
-----
1893+
This method is the DataFrame version of ``ndarray.argmax``.
1894+
1895+
Examples
1896+
--------
1897+
Consider a dataset containing food consumption in Argentina.
1898+
1899+
>>> df = pd.DataFrame({'consumption': [10.51, 103.11, 55.48],
1900+
... 'co2_emissions': [37.2, 19.66, 1712]},
1901+
... index=['Pork', 'Wheat Products', 'Beef'])
1902+
1903+
>>> df
1904+
consumption co2_emissions
1905+
Pork 10.51 37.20
1906+
Wheat Products 103.11 19.66
1907+
Beef 55.48 1712.00
1908+
1909+
By default, it returns the index for the maximum value in each column.
1910+
1911+
>>> df.idxmax()
1912+
consumption Wheat Products
1913+
co2_emissions Beef
1914+
dtype: object
1915+
1916+
To return the index for the maximum value in each row, use ``axis="columns"``.
1917+
1918+
>>> df.idxmax(axis="columns")
1919+
Pork co2_emissions
1920+
Wheat Products consumption
1921+
Beef co2_emissions
1922+
dtype: object
1923+
"""
1924+
if axis is None:
1925+
axis = self.axis
18621926

18631927
def func(df):
18641928
res = df._reduce(
@@ -1879,17 +1943,82 @@ def func(df):
18791943
)
18801944
return result
18811945

1882-
@doc(
1883-
_shared_docs["idxmin"],
1884-
numeric_only_default="False",
1885-
)
18861946
def idxmin(
18871947
self,
1888-
axis: Axis = 0,
1948+
axis: Axis | None = None,
18891949
skipna: bool = True,
18901950
numeric_only: bool = False,
18911951
) -> DataFrame:
1892-
axis = DataFrame._get_axis_number(axis)
1952+
"""
1953+
Return index of first occurrence of minimum over requested axis.
1954+
1955+
NA/null values are excluded.
1956+
1957+
Parameters
1958+
----------
1959+
axis : {0 or 'index', 1 or 'columns'}, default None
1960+
The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise.
1961+
If axis is not provided, grouper's axis is used.
1962+
1963+
.. versionchanged:: 2.0.0
1964+
1965+
skipna : bool, default True
1966+
Exclude NA/null values. If an entire row/column is NA, the result
1967+
will be NA.
1968+
numeric_only : bool, default False
1969+
Include only `float`, `int` or `boolean` data.
1970+
1971+
.. versionadded:: 1.5.0
1972+
1973+
Returns
1974+
-------
1975+
DataFrame
1976+
Indexes of minima along the specified axis.
1977+
1978+
Raises
1979+
------
1980+
ValueError
1981+
* If the row/column is empty
1982+
1983+
See Also
1984+
--------
1985+
Series.idxmin : Return index of the minimum element.
1986+
1987+
Notes
1988+
-----
1989+
This method is the DataFrame version of ``ndarray.argmin``.
1990+
1991+
Examples
1992+
--------
1993+
Consider a dataset containing food consumption in Argentina.
1994+
1995+
>>> df = pd.DataFrame({'consumption': [10.51, 103.11, 55.48],
1996+
... 'co2_emissions': [37.2, 19.66, 1712]},
1997+
... index=['Pork', 'Wheat Products', 'Beef'])
1998+
1999+
>>> df
2000+
consumption co2_emissions
2001+
Pork 10.51 37.20
2002+
Wheat Products 103.11 19.66
2003+
Beef 55.48 1712.00
2004+
2005+
By default, it returns the index for the minimum value in each column.
2006+
2007+
>>> df.idxmin()
2008+
consumption Pork
2009+
co2_emissions Wheat Products
2010+
dtype: object
2011+
2012+
To return the index for the minimum value in each row, use ``axis="columns"``.
2013+
2014+
>>> df.idxmin(axis="columns")
2015+
Pork consumption
2016+
Wheat Products co2_emissions
2017+
Beef consumption
2018+
dtype: object
2019+
"""
2020+
if axis is None:
2021+
axis = self.axis
18932022

18942023
def func(df):
18952024
res = df._reduce(

pandas/tests/groupby/transform/test_transform.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,6 @@ def test_transform_axis_1_reducer(request, reduction_func):
185185
# GH#45715
186186
if reduction_func in (
187187
"corrwith",
188-
"idxmax",
189-
"idxmin",
190188
"ngroup",
191189
"nth",
192190
):

0 commit comments

Comments
 (0)