Skip to content

Commit

Permalink
BUG: groupby.transform/agg with engine='numba' and a MultiIndex (pand…
Browse files Browse the repository at this point in the history
…as-dev#47057)

Co-authored-by: Jeff Reback <jeff@reback.net>
  • Loading branch information
2 people authored and yehoshuadimarsky committed Jul 13, 2022
1 parent 9f9ba46 commit bd5cb28
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.4.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Fixed regressions
~~~~~~~~~~~~~~~~~
- Fixed regression in :meth:`DataFrame.nsmallest` that led to wrong results when ``np.nan`` was in the sorting column (:issue:`46589`)
- Fixed regression in :func:`read_fwf` raising ``ValueError`` when ``widths`` was specified with ``usecols`` (:issue:`46580`)
- Fixed regression in :meth:`.GroupBy.transform` and :meth:`.GroupBy.agg` failing with ``engine="numba"`` when the index was a :class:`MultiIndex` (:issue:`46867`)
- Fixed regression in :meth:`.Styler.to_latex` and :meth:`.Styler.to_html` where ``buf`` failed in combination with ``encoding`` (:issue:`47053`)

.. ---------------------------------------------------------------------------
Expand Down
11 changes: 10 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,16 @@ def _numba_prep(self, data):
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)

sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
sorted_index_data = data.index.take(sorted_index).to_numpy()
if len(self.grouper.groupings) > 1:
raise NotImplementedError(
"More than 1 grouping labels are not supported with engine='numba'"
)
# GH 46867
index_data = data.index
if isinstance(index_data, MultiIndex):
group_key = self.grouper.groupings[0].name
index_data = index_data.get_level_values(group_key)
sorted_index_data = index_data.take(sorted_index).to_numpy()

starts, ends = lib.generate_slices(sorted_ids, ngroups)
return (
Expand Down
27 changes: 27 additions & 0 deletions pandas/tests/groupby/aggregate/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,30 @@ def func_kwargs(values, index):
)
expected = DataFrame({"value": [1.0, 1.0, 1.0]})
tm.assert_frame_equal(result, expected)


@td.skip_if_no("numba")
def test_multiindex_one_key(nogil, parallel, nopython):
    """Aggregate with engine="numba" when grouping a MultiIndex frame by one level.

    Regression test for GH 46867.
    """

    def constant_udf(values, index):
        # Ignores both inputs; every group aggregates to the constant 1.
        return 1

    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
    frame = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])

    grouped = frame.groupby("A")
    result = grouped.agg(constant_udf, engine="numba", engine_kwargs=jit_options)

    expected = DataFrame([1.0], index=Index([1], name="A"), columns=["C"])
    tm.assert_frame_equal(result, expected)


@td.skip_if_no("numba")
def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
    """Aggregating over two grouping keys with engine="numba" must raise.

    Companion to the single-key regression test for GH 46867.
    """

    def constant_udf(values, index):
        # Ignores both inputs; only the raised error matters in this test.
        return 1

    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
    frame = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])

    with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
        frame.groupby(["A", "B"]).agg(
            constant_udf, engine="numba", engine_kwargs=jit_options
        )
27 changes: 27 additions & 0 deletions pandas/tests/groupby/transform/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,30 @@ def func_kwargs(values, index):
)
expected = DataFrame({"value": [1.0, 1.0, 1.0]})
tm.assert_frame_equal(result, expected)


@td.skip_if_no("numba")
def test_multiindex_one_key(nogil, parallel, nopython):
    """Transform with engine="numba" when grouping a MultiIndex frame by one level.

    Regression test for GH 46867.
    """

    def constant_udf(values, index):
        # Ignores both inputs; every transformed value becomes the constant 1.
        return 1

    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
    frame = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])

    grouped = frame.groupby("A")
    result = grouped.transform(
        constant_udf, engine="numba", engine_kwargs=jit_options
    )

    # The original two-level index is kept; only the "C" values change.
    expected = DataFrame([{"A": 1, "B": 2, "C": 1.0}]).set_index(["A", "B"])
    tm.assert_frame_equal(result, expected)


@td.skip_if_no("numba")
def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
    """Transforming over two grouping keys with engine="numba" must raise.

    Companion to the single-key regression test for GH 46867.
    """

    def constant_udf(values, index):
        # Ignores both inputs; only the raised error matters in this test.
        return 1

    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
    frame = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])

    with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
        frame.groupby(["A", "B"]).transform(
            constant_udf, engine="numba", engine_kwargs=jit_options
        )

0 comments on commit bd5cb28

Please sign in to comment.