|
14 | 14 | qcut,
|
15 | 15 | )
|
16 | 16 | import pandas._testing as tm
|
| 17 | +from pandas.core.groupby.generic import SeriesGroupBy |
17 | 18 | from pandas.tests.groupby import get_groupby_method_args
|
18 | 19 |
|
19 | 20 |
|
@@ -2007,3 +2008,50 @@ def test_many_categories(as_index, sort, index_kind, ordered):
|
2007 | 2008 | expected = DataFrame({"a": Series(index), "b": data})
|
2008 | 2009 |
|
2009 | 2010 | tm.assert_frame_equal(result, expected)
|
| 2011 | + |
| 2012 | + |
| 2013 | +@pytest.mark.parametrize("test_series", [True, False]) |
| 2014 | +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) |
| 2015 | +def test_agg_list(request, as_index, observed, reduction_func, test_series, keys): |
| 2016 | + # GH#52760 |
| 2017 | + if test_series and reduction_func == "corrwith": |
| 2018 | + assert not hasattr(SeriesGroupBy, "corrwith") |
| 2019 | + pytest.skip("corrwith not implemented for SeriesGroupBy") |
| 2020 | + elif reduction_func == "corrwith": |
| 2021 | + msg = "GH#32293: attempts to call SeriesGroupBy.corrwith" |
| 2022 | + request.node.add_marker(pytest.mark.xfail(reason=msg)) |
| 2023 | + elif ( |
| 2024 | + reduction_func == "nunique" |
| 2025 | + and not test_series |
| 2026 | + and len(keys) != 1 |
| 2027 | + and not observed |
| 2028 | + and not as_index |
| 2029 | + ): |
| 2030 | + msg = "GH#52848 - raises a ValueError" |
| 2031 | + request.node.add_marker(pytest.mark.xfail(reason=msg)) |
| 2032 | + |
| 2033 | + df = DataFrame({"a1": [0, 0, 1], "a2": [2, 3, 3], "b": [4, 5, 6]}) |
| 2034 | + df = df.astype({"a1": "category", "a2": "category"}) |
| 2035 | + if "a2" not in keys: |
| 2036 | + df = df.drop(columns="a2") |
| 2037 | + gb = df.groupby(by=keys, as_index=as_index, observed=observed) |
| 2038 | + if test_series: |
| 2039 | + gb = gb["b"] |
| 2040 | + args = get_groupby_method_args(reduction_func, df) |
| 2041 | + |
| 2042 | + result = gb.agg([reduction_func], *args) |
| 2043 | + expected = getattr(gb, reduction_func)(*args) |
| 2044 | + |
| 2045 | + if as_index and (test_series or reduction_func == "size"): |
| 2046 | + expected = expected.to_frame(reduction_func) |
| 2047 | + if not test_series: |
| 2048 | + if not as_index: |
| 2049 | + # TODO: GH#52849 - as_index=False is not respected |
| 2050 | + expected = expected.set_index(keys) |
| 2051 | + expected.columns = MultiIndex( |
| 2052 | + levels=[["b"], [reduction_func]], codes=[[0], [0]] |
| 2053 | + ) |
| 2054 | + elif not as_index: |
| 2055 | + expected.columns = keys + [reduction_func] |
| 2056 | + |
| 2057 | + tm.assert_equal(result, expected) |
0 commit comments