Skip to content

Commit 4628665

Browse files
authored
TST: GroupBy(..., as_index=True).agg() drops index (#33098)
1 parent 594dc2a commit 4628665

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

pandas/tests/groupby/aggregate/test_aggregate.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,17 @@ def test_aggregate_mixed_types():
809809
tm.assert_frame_equal(result, expected)
810810

811811

812-
@pytest.mark.xfail(reason="Not implemented.")
812+
@pytest.mark.parametrize("func", ["min", "max"])
813+
def test_aggregate_categorical_lost_index(func: str):
814+
# GH: 28641 groupby drops index, when grouping over categorical column with min/max
815+
ds = pd.Series(["b"], dtype="category").cat.as_ordered()
816+
df = pd.DataFrame({"A": [1997], "B": ds})
817+
result = df.groupby("A").agg({"B": func})
818+
expected = pd.DataFrame({"B": ["b"]}, index=pd.Index([1997], name="A"))
819+
tm.assert_frame_equal(result, expected)
820+
821+
822+
@pytest.mark.xfail(reason="Not implemented;see GH 31256")
813823
def test_aggregate_udf_na_extension_type():
814824
# https://github.com/pandas-dev/pandas/pull/31359
815825
# This is currently failing to cast back to Int64Dtype.

pandas/tests/groupby/test_categorical.py

+13
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,19 @@ def test_groupby_agg_non_numeric():
13881388
tm.assert_frame_equal(result, expected)
13891389

13901390

1391+
@pytest.mark.parametrize("func", ["first", "last"])
1392+
def test_groupy_first_returned_categorical_instead_of_dataframe(func):
1393+
# GH 28641: groupby drops index, when grouping over categorical column with
1394+
# first/last. Renamed Categorical instead of DataFrame previously.
1395+
df = pd.DataFrame(
1396+
{"A": [1997], "B": pd.Series(["b"], dtype="category").cat.as_ordered()}
1397+
)
1398+
df_grouped = df.groupby("A")["B"]
1399+
result = getattr(df_grouped, func)()
1400+
expected = pd.Series(["b"], index=pd.Index([1997], name="A"), name="B")
1401+
tm.assert_series_equal(result, expected)
1402+
1403+
13911404
def test_read_only_category_no_sort():
13921405
# GH33410
13931406
cats = np.array([1, 2])

0 commit comments

Comments
 (0)