Skip to content

Commit 5f476d9

Browse files
author
Matt Roeschke
committed
WIP: Add Numba to rolling.apply
1 parent 1bff9e1 commit 5f476d9

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

pandas/core/window/aggregators/methods.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Callable
2+
3+
import numba
14
import numpy as np
25

36

@@ -33,3 +36,23 @@ def rolling_mean(
3336
val = total / count
3437
result[i] = val
3538
return result
39+
40+
41+
@numba.njit(nogil=True)
42+
def rolling_apply(
43+
values: np.ndarray,
44+
begin: np.ndarray,
45+
end: np.ndarray,
46+
minimum_periods: int,
47+
numba_func: Callable,
48+
args,
49+
) -> np.ndarray:
50+
result = np.empty(len(begin))
51+
for i, (start, stop) in enumerate(zip(begin, end)):
52+
window = values[start:stop]
53+
count_nan = np.sum(np.isnan(window))
54+
if len(window) - count_nan >= minimum_periods:
55+
result[i] = numba_func(window, *args)
56+
else:
57+
result[i] = np.nan
58+
return result

pandas/core/window/rolling.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Callable, List, Optional, Set, Union
99
import warnings
1010

11+
import numba
1112
import numpy as np
1213

1314
import pandas._libs.window as libwindow
@@ -1144,7 +1145,20 @@ def f(arg, window, min_periods, closed):
11441145
kwargs,
11451146
)
11461147

1147-
return self._apply(f, func, args=args, kwargs=kwargs, center=False, raw=raw)
1148+
# Numba doesn't support kwargs in nopython mode
1149+
# https://github.com/numba/numba/issues/2916
1150+
numba_func = numba.njit(func)
1151+
rolling_apply = partial(methods.rolling_apply, numba_func=numba_func, args=args)
1152+
1153+
return self._apply(
1154+
rolling_apply,
1155+
func,
1156+
args=args,
1157+
kwargs=kwargs,
1158+
center=False,
1159+
raw=raw,
1160+
use_numba=True,
1161+
)
11481162

11491163
def sum(self, *args, **kwargs):
11501164
nv.validate_window_func("sum", args, kwargs)

0 commit comments

Comments
 (0)