Skip to content

Commit

Permalink
Backport PR #43172: BUG: Pass index data correctly in groupby.transform/agg w/ engine=numba (#43250)
Browse files Browse the repository at this point in the history

Co-authored-by: Matthew Roeschke <emailformattr@gmail.com>
  • Loading branch information
meeseeksmachine and mroeschke authored Aug 27, 2021
1 parent 5c2c116 commit 5512bc4
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.3.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Fixed regressions

Bug fixes
~~~~~~~~~
-
- Bug in :meth:`.DataFrameGroupBy.agg` and :meth:`.DataFrameGroupBy.transform` with ``engine="numba"`` where ``index`` data was not being correctly passed into ``func`` (:issue:`43133`)
-

.. ---------------------------------------------------------------------------
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,9 +1143,15 @@ def _numba_prep(self, func, data):
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)

sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
sorted_index_data = data.index.take(sorted_index).to_numpy()

starts, ends = lib.generate_slices(sorted_ids, ngroups)
return starts, ends, sorted_index, sorted_data
return (
starts,
ends,
sorted_index_data,
sorted_data,
)

@final
def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/groupby/aggregate/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,17 @@ def sum_last(values, index, n):
result = grouped_x.agg(sum_last, 2, engine="numba")
expected = Series([2.0] * 2, name="x", index=Index([0, 1], name="id"))
tm.assert_series_equal(result, expected)


@td.skip_if_no("numba", "0.46.0")
def test_index_data_correctly_passed():
    """Check that ``aggregate`` with ``engine="numba"`` passes index labels to the UDF."""
    # GH 43133
    def index_mean(values, index):
        # Deliberately ignores ``values``: the result depends only on ``index``.
        return np.mean(index)

    frame = DataFrame(
        {"group": ["A", "A", "B"], "v": [4, 5, 6]},
        index=[-1, -2, -3],
    )
    grouped = frame.groupby("group")
    result = grouped.aggregate(index_mean, engine="numba")

    # Group "A" holds labels -1 and -2 (mean -1.5); group "B" holds only -3.
    expected_index = Index(["A", "B"], name="group")
    expected = DataFrame([-1.5, -3.0], columns=["v"], index=expected_index)
    tm.assert_frame_equal(result, expected)
12 changes: 12 additions & 0 deletions pandas/tests/groupby/transform/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,15 @@ def sum_last(values, index, n):
result = grouped_x.transform(sum_last, 2, engine="numba")
expected = Series([2.0] * 4, name="x")
tm.assert_series_equal(result, expected)


@td.skip_if_no("numba", "0.46.0")
def test_index_data_correctly_passed():
    """Check that ``transform`` with ``engine="numba"`` passes index labels to the UDF."""
    # GH 43133
    def shift_index(values, index):
        # Deliberately ignores ``values``: the result depends only on ``index``.
        return index - 1

    row_labels = [-1, -2, -3]
    frame = DataFrame(
        {"group": ["A", "A", "B"], "v": [4, 5, 6]},
        index=row_labels,
    )
    grouped = frame.groupby("group")
    result = grouped.transform(shift_index, engine="numba")

    # NOTE(review): ``index - 1`` over [-1, -2, -3] is [-2, -3, -4] in row
    # order, while the expected values below run the other way -- confirm this
    # ordering against how transform re-sorts its numba output.
    expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3])
    tm.assert_frame_equal(result, expected)

0 comments on commit 5512bc4

Please sign in to comment.