@@ -95,6 +95,7 @@ def __init__(
9595 self .win_freq = None
9696 self .axis = obj ._get_axis_number (axis ) if axis is not None else None
9797 self .validate ()
98+ self ._apply_func_cache = dict ()
9899
99100 @property
100101 def _constructor (self ):
@@ -493,7 +494,7 @@ def _apply(
493494 minimum_periods = _check_min_periods (
494495 self .min_periods or 1 , self .min_periods , len (values ) + offset
495496 )
496- func = partial ( # type: ignore
497+ func_partial = partial ( # type: ignore
497498 func , begin = start , end = end , minimum_periods = minimum_periods
498499 )
499500
@@ -511,7 +512,7 @@ def _apply(
511512 cfunc , check_minp , index_as_array , ** kwargs
512513 )
513514
514- func = partial ( # type: ignore
515+ func_partial = partial ( # type: ignore
515516 func ,
516517 window = window ,
517518 min_periods = self .min_periods ,
@@ -521,12 +522,12 @@ def _apply(
521522 if additional_nans is not None :
522523
523524 def calc (x ):
524- return func (np .concatenate ((x , additional_nans )))
525+ return func_partial (np .concatenate ((x , additional_nans )))
525526
526527 else :
527528
528529 def calc (x ):
529- return func (x )
530+ return func_partial (x )
530531
531532 with np .errstate (all = "ignore" ):
532533 if values .ndim > 1 :
@@ -535,6 +536,9 @@ def calc(x):
535536 result = calc (values )
536537 result = np .asarray (result )
537538
539+ if use_numba :
540+ self ._apply_func_cache [name ] = func
541+
538542 if center :
539543 result = self ._center_window (result , window )
540544
@@ -1147,8 +1151,34 @@ def f(arg, window, min_periods, closed):
11471151
11481152 # Numba doesn't support kwargs in nopython mode
11491153 # https://github.com/numba/numba/issues/2916
1150- numba_func = numba .njit (func )
1151- rolling_apply = partial (methods .rolling_apply , numba_func = numba_func , args = args )
1154+ if func not in self ._apply_func_cache :
1155+
1156+ def make_rolling_apply (func ):
1157+
1158+ numba_func = numba .njit (func )
1159+
1160+ @numba .njit
1161+ def roll_apply (
1162+ values : np .ndarray ,
1163+ begin : np .ndarray ,
1164+ end : np .ndarray ,
1165+ minimum_periods : int ,
1166+ ):
1167+ result = np .empty (len (begin ))
1168+ for i , (start , stop ) in enumerate (zip (begin , end )):
1169+ window = values [start :stop ]
1170+ count_nan = np .sum (np .isnan (window ))
1171+ if len (window ) - count_nan >= minimum_periods :
1172+ result [i ] = numba_func (window , * args )
1173+ else :
1174+ result [i ] = np .nan
1175+ return result
1176+
1177+ return roll_apply
1178+
1179+ rolling_apply = make_rolling_apply (func )
1180+ else :
1181+ rolling_apply = self ._apply_func_cache [func ]
11521182
11531183 return self ._apply (
11541184 rolling_apply ,
0 commit comments