Skip to content

Commit

Permalink
ENH: Numba groupby support multiple labels (#53556)
Browse files Browse the repository at this point in the history
* ENH: Numba groupby support multiple labels

* update regex
  • Loading branch information
lithomas1 authored Jun 7, 2023
1 parent eca28a3 commit 4e83066
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 6 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Other enhancements
- :meth:`Categorical.from_codes` has gotten a ``validate`` parameter (:issue:`50975`)
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
- :meth:`DataFrameGroupBy.agg` and :meth:`DataFrameGroupBy.transform` now support grouping by multiple keys when the index is not a :class:`MultiIndex` for ``engine="numba"`` (:issue:`53486`)
- :meth:`SeriesGroupBy.agg` and :meth:`DataFrameGroupBy.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
- Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
Expand Down
9 changes: 5 additions & 4 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,13 +1453,14 @@ def _numba_prep(self, data: DataFrame):
sorted_ids = self.grouper._sorted_ids

sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
if len(self.grouper.groupings) > 1:
raise NotImplementedError(
"More than 1 grouping labels are not supported with engine='numba'"
)
# GH 46867
index_data = data.index
if isinstance(index_data, MultiIndex):
if len(self.grouper.groupings) > 1:
raise NotImplementedError(
"Grouping with more than 1 grouping labels and "
"a MultiIndex is not supported with engine='numba'"
)
group_key = self.grouper.groupings[0].name
index_data = index_data.get_level_values(group_key)
sorted_index_data = index_data.take(sorted_index).to_numpy()
Expand Down
39 changes: 38 additions & 1 deletion pandas/tests/groupby/aggregate/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,44 @@ def numba_func(values, index):

df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
with pytest.raises(NotImplementedError, match="more than 1 grouping labels"):
df.groupby(["A", "B"]).agg(
numba_func, engine="numba", engine_kwargs=engine_kwargs
)


@td.skip_if_no("numba")
def test_multilabel_numba_vs_cython(numba_supported_reductions):
    """Compare the numba and cython engines when aggregating over two keys.

    ``numba_supported_reductions`` is unpacked into a reduction name and the
    keyword arguments to forward to it. The reduction is exercised both
    through ``agg`` and by calling the groupby method of the same name.
    """
    reduction, kwargs = numba_supported_reductions
    # Seeded generator: if the engines ever disagree, the failing input can
    # be regenerated exactly instead of depending on global RNG state.
    rng = np.random.default_rng(2)
    df = DataFrame(
        {
            "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
            "B": ["one", "one", "two", "three", "two", "two", "one", "three"],
            "C": rng.standard_normal(8),
            "D": rng.standard_normal(8),
        }
    )
    gb = df.groupby(["A", "B"])
    res_agg = gb.agg(reduction, engine="numba", **kwargs)
    expected_agg = gb.agg(reduction, engine="cython", **kwargs)
    tm.assert_frame_equal(res_agg, expected_agg)
    # Test that calling the aggregation directly also works
    direct_res = getattr(gb, reduction)(engine="numba", **kwargs)
    direct_expected = getattr(gb, reduction)(engine="cython", **kwargs)
    tm.assert_frame_equal(direct_res, direct_expected)


@td.skip_if_no("numba")
def test_multilabel_udf_numba_vs_cython():
    """Compare a user-defined ``min`` aggregation across engines with two keys.

    The numba engine is given a ``(values, index)`` callable and the cython
    engine a single-argument callable; both compute the group minimum and
    the resulting frames must be equal.
    """
    # Seeded generator: if the engines ever disagree, the failing input can
    # be regenerated exactly instead of depending on global RNG state.
    rng = np.random.default_rng(2)
    df = DataFrame(
        {
            "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
            "B": ["one", "one", "two", "three", "two", "two", "one", "three"],
            "C": rng.standard_normal(8),
            "D": rng.standard_normal(8),
        }
    )
    gb = df.groupby(["A", "B"])
    result = gb.agg(lambda values, index: values.min(), engine="numba")
    expected = gb.agg(lambda x: x.min(), engine="cython")
    tm.assert_frame_equal(result, expected)
44 changes: 43 additions & 1 deletion pandas/tests/groupby/transform/test_numba.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest

from pandas.errors import NumbaUtilError
Expand Down Expand Up @@ -224,7 +225,48 @@ def numba_func(values, index):

df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
with pytest.raises(NotImplementedError, match="more than 1 grouping labels"):
df.groupby(["A", "B"]).transform(
numba_func, engine="numba", engine_kwargs=engine_kwargs
)


@td.skip_if_no("numba")
@pytest.mark.xfail(
    reason="Groupby transform doesn't support strings as function inputs yet with numba"
)
def test_multilabel_numba_vs_cython(numba_supported_reductions):
    """Compare the numba and cython engines for named transforms over two keys.

    ``numba_supported_reductions`` is unpacked into a reduction name and the
    keyword arguments to forward to it. The test is marked ``xfail``; the
    reason given in the marker is that ``transform`` does not yet accept
    string function names with the numba engine.
    """
    reduction, kwargs = numba_supported_reductions
    # Seeded generator: if the engines ever disagree, the failing input can
    # be regenerated exactly instead of depending on global RNG state.
    rng = np.random.default_rng(2)
    df = DataFrame(
        {
            "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
            "B": ["one", "one", "two", "three", "two", "two", "one", "three"],
            "C": rng.standard_normal(8),
            "D": rng.standard_normal(8),
        }
    )
    gb = df.groupby(["A", "B"])
    result = gb.transform(reduction, engine="numba", **kwargs)
    expected = gb.transform(reduction, engine="cython", **kwargs)
    tm.assert_frame_equal(result, expected)


@td.skip_if_no("numba")
def test_multilabel_udf_numba_vs_cython():
    """Compare a user-defined min-max scaling transform across engines.

    The numba engine is given a ``(values, index)`` callable and the cython
    engine a single-argument callable; both apply the same
    ``(x - min) / (max - min)`` expression per group and the resulting frames
    must be equal.
    """
    # Seeded generator: if the engines ever disagree, the failing input can
    # be regenerated exactly instead of depending on global RNG state.
    rng = np.random.default_rng(2)
    df = DataFrame(
        {
            "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
            "B": ["one", "one", "two", "three", "two", "two", "one", "three"],
            "C": rng.standard_normal(8),
            "D": rng.standard_normal(8),
        }
    )
    gb = df.groupby(["A", "B"])
    result = gb.transform(
        lambda values, index: (values - values.min()) / (values.max() - values.min()),
        engine="numba",
    )
    expected = gb.transform(
        lambda x: (x - x.min()) / (x.max() - x.min()), engine="cython"
    )
    tm.assert_frame_equal(result, expected)

0 comments on commit 4e83066

Please sign in to comment.