Skip to content

Commit e0547d1

Browse files
authored
ENH: Support groupby.ewm operations (#37878)
1 parent 9b38a01 commit e0547d1

File tree

15 files changed

+399
-64
lines changed

15 files changed

+399
-64
lines changed

asv_bench/benchmarks/rolling.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,4 +225,17 @@ def time_rolling_offset(self, method):
225225
getattr(self.groupby_roll_offset, method)()
226226

227227

228+
class GroupbyEWM:
229+
230+
params = ["cython", "numba"]
231+
param_names = ["engine"]
232+
233+
def setup(self, engine):
234+
df = pd.DataFrame({"A": range(50), "B": range(50)})
235+
self.gb_ewm = df.groupby("A").ewm(com=1.0)
236+
237+
def time_groupby_mean(self, engine):
238+
self.gb_ewm.mean(engine=engine)
239+
240+
228241
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/user_guide/window.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Concept Method Returned Object
4343
Rolling window ``rolling`` ``Rolling`` Yes Yes
4444
Weighted window ``rolling`` ``Window`` No No
4545
Expanding window ``expanding`` ``Expanding`` No Yes
46-
Exponentially Weighted window ``ewm`` ``ExponentialMovingWindow`` No No
46+
Exponentially Weighted window ``ewm`` ``ExponentialMovingWindow`` No Yes (as of version 1.2)
4747
============================= ================= =========================== =========================== ========================
4848

4949
As noted above, some operations support specifying a window based on a time offset:

doc/source/whatsnew/v1.2.0.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,23 @@ example where the index name is preserved:
204204
The same is true for :class:`MultiIndex`, but the logic is applied separately on a
205205
level-by-level basis.
206206

207+
.. _whatsnew_120.groupby_ewm:
208+
209+
Groupby supports EWM operations directly
210+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
211+
212+
:class:`DataFrameGroupBy` now supports exponentially weighted window operations directly (:issue:`16037`).
213+
214+
.. ipython:: python
215+
216+
df = pd.DataFrame({'A': ['a', 'b', 'a', 'b'], 'B': range(4)})
217+
df
218+
df.groupby('A').ewm(com=1.0).mean()
219+
220+
Additionally ``mean`` supports execution via `Numba <https://numba.pydata.org/>`__ with
221+
the ``engine`` and ``engine_kwargs`` arguments. Numba must be installed as an optional dependency
222+
to use this feature.
223+
207224
.. _whatsnew_120.enhancements.other:
208225

209226
Other enhancements

pandas/_libs/window/aggregations.pyx

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,15 +1496,17 @@ def roll_weighted_var(float64_t[:] values, float64_t[:] weights,
14961496
# ----------------------------------------------------------------------
14971497
# Exponentially weighted moving average
14981498

1499-
def ewma_time(const float64_t[:] vals, int minp, ndarray[int64_t] times,
1500-
int64_t halflife):
1499+
def ewma_time(const float64_t[:] vals, int64_t[:] start, int64_t[:] end,
1500+
int minp, ndarray[int64_t] times, int64_t halflife):
15011501
"""
15021502
Compute exponentially-weighted moving average using halflife and time
15031503
distances.
15041504
15051505
Parameters
15061506
----------
15071507
vals : ndarray[float_64]
1508+
start : ndarray[int_64]
1509+
end : ndarray[int_64]
15081510
minp : int
15091511
times : ndarray[int64]
15101512
halflife : int64
@@ -1552,17 +1554,20 @@ def ewma_time(const float64_t[:] vals, int minp, ndarray[int64_t] times,
15521554
return output
15531555

15541556

1555-
def ewma(float64_t[:] vals, float64_t com, bint adjust, bint ignore_na, int minp):
1557+
def ewma(float64_t[:] vals, int64_t[:] start, int64_t[:] end, int minp,
1558+
float64_t com, bint adjust, bint ignore_na):
15561559
"""
15571560
Compute exponentially-weighted moving average using center-of-mass.
15581561
15591562
Parameters
15601563
----------
15611564
vals : ndarray (float64 type)
1565+
start : ndarray (int64 type)
1566+
end : ndarray (int64 type)
1567+
minp : int
15621568
com : float64
15631569
adjust : bool
15641570
ignore_na : bool
1565-
minp : int
15661571
15671572
Returns
15681573
-------
@@ -1620,19 +1625,21 @@ def ewma(float64_t[:] vals, float64_t com, bint adjust, bint ignore_na, int minp
16201625
# Exponentially weighted moving covariance
16211626

16221627

1623-
def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
1624-
float64_t com, bint adjust, bint ignore_na, int minp, bint bias):
1628+
def ewmcov(float64_t[:] input_x, int64_t[:] start, int64_t[:] end, int minp,
1629+
float64_t[:] input_y, float64_t com, bint adjust, bint ignore_na, bint bias):
16251630
"""
16261631
Compute exponentially-weighted moving covariance using center-of-mass.
16271632
16281633
Parameters
16291634
----------
16301635
input_x : ndarray (float64 type)
1636+
start : ndarray (int64 type)
1637+
end : ndarray (int64 type)
1638+
minp : int
16311639
input_y : ndarray (float64 type)
16321640
com : float64
16331641
adjust : bool
16341642
ignore_na : bool
1635-
minp : int
16361643
bias : bool
16371644
16381645
Returns

pandas/core/groupby/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def _gotitem(self, key, ndim, subset=None):
192192
"describe",
193193
"dtypes",
194194
"expanding",
195+
"ewm",
195196
"filter",
196197
"get_group",
197198
"groups",

pandas/core/groupby/groupby.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,6 +1859,16 @@ def expanding(self, *args, **kwargs):
18591859

18601860
return ExpandingGroupby(self, *args, **kwargs)
18611861

1862+
@Substitution(name="groupby")
1863+
@Appender(_common_see_also)
1864+
def ewm(self, *args, **kwargs):
1865+
"""
1866+
Return an ewm grouper, providing ewm functionality per group.
1867+
"""
1868+
from pandas.core.window import ExponentialMovingWindowGroupby
1869+
1870+
return ExponentialMovingWindowGroupby(self, *args, **kwargs)
1871+
18621872
def _fill(self, direction, limit=None):
18631873
"""
18641874
Shared function for `pad` and `backfill` to call Cython method.

pandas/core/window/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1-
from pandas.core.window.ewm import ExponentialMovingWindow # noqa:F401
1+
from pandas.core.window.ewm import ( # noqa:F401
2+
ExponentialMovingWindow,
3+
ExponentialMovingWindowGroupby,
4+
)
25
from pandas.core.window.expanding import Expanding, ExpandingGroupby # noqa:F401
36
from pandas.core.window.rolling import Rolling, RollingGroupby, Window # noqa:F401

pandas/core/window/ewm.py

Lines changed: 122 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,20 @@
1414
from pandas.core.dtypes.common import is_datetime64_ns_dtype
1515

1616
import pandas.core.common as common
17-
from pandas.core.window.common import _doc_template, _shared_docs, zsqrt
18-
from pandas.core.window.rolling import BaseWindow, flex_binary_moment
17+
from pandas.core.util.numba_ import maybe_use_numba
18+
from pandas.core.window.common import (
19+
_doc_template,
20+
_shared_docs,
21+
flex_binary_moment,
22+
zsqrt,
23+
)
24+
from pandas.core.window.indexers import (
25+
BaseIndexer,
26+
ExponentialMovingWindowIndexer,
27+
GroupbyIndexer,
28+
)
29+
from pandas.core.window.numba_ import generate_numba_groupby_ewma_func
30+
from pandas.core.window.rolling import BaseWindow, BaseWindowGroupby, dispatch
1931

2032
if TYPE_CHECKING:
2133
from pandas import Series
@@ -219,14 +231,16 @@ def __init__(
219231
ignore_na: bool = False,
220232
axis: int = 0,
221233
times: Optional[Union[str, np.ndarray, FrameOrSeries]] = None,
234+
**kwargs,
222235
):
223-
self.com: Optional[float]
224236
self.obj = obj
225237
self.min_periods = max(int(min_periods), 1)
226238
self.adjust = adjust
227239
self.ignore_na = ignore_na
228240
self.axis = axis
229241
self.on = None
242+
self.center = False
243+
self.closed = None
230244
if times is not None:
231245
if isinstance(times, str):
232246
times = self._selected_obj[times]
@@ -245,7 +259,7 @@ def __init__(
245259
if common.count_not_none(com, span, alpha) > 0:
246260
self.com = get_center_of_mass(com, span, None, alpha)
247261
else:
248-
self.com = None
262+
self.com = 0.0
249263
else:
250264
if halflife is not None and isinstance(halflife, (str, datetime.timedelta)):
251265
raise ValueError(
@@ -260,6 +274,12 @@ def __init__(
260274
def _constructor(self):
261275
return ExponentialMovingWindow
262276

277+
def _get_window_indexer(self) -> BaseIndexer:
278+
"""
279+
Return an indexer class that will compute the window start and end bounds
280+
"""
281+
return ExponentialMovingWindowIndexer()
282+
263283
_agg_see_also_doc = dedent(
264284
"""
265285
See Also
@@ -299,27 +319,6 @@ def aggregate(self, func, *args, **kwargs):
299319

300320
agg = aggregate
301321

302-
def _apply(self, func):
303-
"""
304-
Rolling statistical measure using supplied function. Designed to be
305-
used with passed-in Cython array-based functions.
306-
307-
Parameters
308-
----------
309-
func : str/callable to apply
310-
311-
Returns
312-
-------
313-
y : same type as input argument
314-
"""
315-
316-
def homogeneous_func(values: np.ndarray):
317-
if values.size == 0:
318-
return values.copy()
319-
return np.apply_along_axis(func, self.axis, values)
320-
321-
return self._apply_blockwise(homogeneous_func)
322-
323322
@Substitution(name="ewm", func_name="mean")
324323
@Appender(_doc_template)
325324
def mean(self, *args, **kwargs):
@@ -336,7 +335,6 @@ def mean(self, *args, **kwargs):
336335
window_func = self._get_roll_func("ewma_time")
337336
window_func = partial(
338337
window_func,
339-
minp=self.min_periods,
340338
times=self.times,
341339
halflife=self.halflife,
342340
)
@@ -347,7 +345,6 @@ def mean(self, *args, **kwargs):
347345
com=self.com,
348346
adjust=self.adjust,
349347
ignore_na=self.ignore_na,
350-
minp=self.min_periods,
351348
)
352349
return self._apply(window_func)
353350

@@ -371,13 +368,19 @@ def var(self, bias: bool = False, *args, **kwargs):
371368
Exponential weighted moving variance.
372369
"""
373370
nv.validate_window_func("var", args, kwargs)
371+
window_func = self._get_roll_func("ewmcov")
372+
window_func = partial(
373+
window_func,
374+
com=self.com,
375+
adjust=self.adjust,
376+
ignore_na=self.ignore_na,
377+
bias=bias,
378+
)
374379

375-
def f(arg):
376-
return window_aggregations.ewmcov(
377-
arg, arg, self.com, self.adjust, self.ignore_na, self.min_periods, bias
378-
)
380+
def var_func(values, begin, end, min_periods):
381+
return window_func(values, begin, end, min_periods, values)
379382

380-
return self._apply(f)
383+
return self._apply(var_func)
381384

382385
@Substitution(name="ewm", func_name="cov")
383386
@Appender(_doc_template)
@@ -419,11 +422,13 @@ def _get_cov(X, Y):
419422
Y = self._shallow_copy(Y)
420423
cov = window_aggregations.ewmcov(
421424
X._prep_values(),
425+
np.array([0], dtype=np.int64),
426+
np.array([0], dtype=np.int64),
427+
self.min_periods,
422428
Y._prep_values(),
423429
self.com,
424430
self.adjust,
425431
self.ignore_na,
426-
self.min_periods,
427432
bias,
428433
)
429434
return wrap_result(X, cov)
@@ -470,7 +475,15 @@ def _get_corr(X, Y):
470475

471476
def _cov(x, y):
472477
return window_aggregations.ewmcov(
473-
x, y, self.com, self.adjust, self.ignore_na, self.min_periods, 1
478+
x,
479+
np.array([0], dtype=np.int64),
480+
np.array([0], dtype=np.int64),
481+
self.min_periods,
482+
y,
483+
self.com,
484+
self.adjust,
485+
self.ignore_na,
486+
1,
474487
)
475488

476489
x_values = X._prep_values()
@@ -485,3 +498,78 @@ def _cov(x, y):
485498
return flex_binary_moment(
486499
self._selected_obj, other._selected_obj, _get_corr, pairwise=bool(pairwise)
487500
)
501+
502+
503+
class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow):
504+
"""
505+
Provide an exponential moving window groupby implementation.
506+
"""
507+
508+
def _get_window_indexer(self) -> GroupbyIndexer:
509+
"""
510+
Return an indexer class that will compute the window start and end bounds
511+
512+
Returns
513+
-------
514+
GroupbyIndexer
515+
"""
516+
window_indexer = GroupbyIndexer(
517+
groupby_indicies=self._groupby.indices,
518+
window_indexer=ExponentialMovingWindowIndexer,
519+
)
520+
return window_indexer
521+
522+
var = dispatch("var", bias=False)
523+
std = dispatch("std", bias=False)
524+
cov = dispatch("cov", other=None, pairwise=None, bias=False)
525+
corr = dispatch("corr", other=None, pairwise=None)
526+
527+
def mean(self, engine=None, engine_kwargs=None):
528+
"""
529+
Parameters
530+
----------
531+
engine : str, default None
532+
* ``'cython'`` : Runs mean through C-extensions from cython.
533+
* ``'numba'`` : Runs mean through JIT compiled code from numba.
534+
Requires Numba to be installed as an optional dependency.
535+
* ``None`` : Defaults to ``'cython'`` or the global setting
536+
``compute.use_numba``
537+
538+
.. versionadded:: 1.2.0
539+
540+
engine_kwargs : dict, default None
541+
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
542+
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
543+
and ``parallel`` dictionary keys. The values must either be ``True`` or
544+
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
545+
``{'nopython': True, 'nogil': False, 'parallel': False}``.
546+
547+
.. versionadded:: 1.2.0
548+
549+
Returns
550+
-------
551+
Series or DataFrame
552+
Return type is determined by the caller.
553+
"""
554+
if maybe_use_numba(engine):
555+
groupby_ewma_func = generate_numba_groupby_ewma_func(
556+
engine_kwargs,
557+
self.com,
558+
self.adjust,
559+
self.ignore_na,
560+
)
561+
return self._apply(
562+
groupby_ewma_func,
563+
numba_cache_key=(lambda x: x, "groupby_ewma"),
564+
)
565+
elif engine in ("cython", None):
566+
if engine_kwargs is not None:
567+
raise ValueError("cython engine does not accept engine_kwargs")
568+
569+
def f(x):
570+
x = self._shallow_copy(x, groupby=self._groupby)
571+
return x.mean()
572+
573+
return self._groupby.apply(f)
574+
else:
575+
raise ValueError("engine must be either 'numba' or 'cython'")

0 commit comments

Comments
 (0)