1- from functools import partial
1+ """
2+ Implementation of the rolling aggregations using jitclasses.
3+
4+ Some current difficulties as of numba 0.45.1:
5+
6+ 1) jitclasses don't support inheritance, i.e. a base jitclass cannot be subclassed.
7+
8+ 2) This implementation is not currently utilized because of
9+ inherent performance penalties.
10+ See https://github.com/numba/numba/issues/4522
11+ """
12+
213from typing import Optional
314
15+ import numba
416import numpy as np
517
618from pandas ._typing import Scalar
@@ -56,13 +68,15 @@ class AggKernel:
5668 make_aggregator
5769 """
5870
71+ def __init__ (self ):
72+ pass
73+
5974 def finalize (self ):
6075 """Return the final value of the aggregation."""
6176 raise NotImplementedError
6277
63- @classmethod
6478 def make_aggregator (
65- cls , values : np .ndarray , minimum_periods : int
79+ self , values : np .ndarray , minimum_periods : int
6680 ) -> BaseAggregator :
6781 """Return an aggregator that performs the aggregation calculation"""
6882 raise NotImplementedError
@@ -80,14 +94,30 @@ def invert(self, value) -> None:
8094 raise NotImplementedError
8195
8296
97+ agg_type = numba .deferred_type ()
98+
99+
100+ base_aggregator_spec = (
101+ ("values" , numba .float64 [:]),
102+ ("min_periods" , numba .uint64 ),
103+ ("agg" , agg_type ),
104+ ("previous_start" , numba .int64 ),
105+ ("previous_end" , numba .int64 ),
106+ )
107+
108+
109+ @numba .jitclass (base_aggregator_spec )
83110class SubtractableAggregator (BaseAggregator ):
84111 """
85112 Aggregator in which a current aggregated value
86113 is offset from a prior aggregated value.
87114 """
88115
89116 def __init__ (self , values : np .ndarray , min_periods : int , agg ) -> None :
90- super ().__init__ (values , min_periods )
117+ # Note: Numba doesn't like inheritance
118+ # super().__init__(values, min_periods)
119+ self .values = values
120+ self .min_periods = min_periods
91121 self .agg = agg
92122 self .previous_start = - 1
93123 self .previous_end = - 1
@@ -108,7 +138,8 @@ def query(self, start: int, stop: int) -> Optional[Scalar]:
108138 self .previous_end = stop
109139 if self .agg .count >= self .min_periods :
110140 return self .agg .finalize ()
111- return None
141+ # Numba wanted this to be None instead of None
142+ return np .nan
112143
113144
114145class Sum (UnaryAggKernel ):
@@ -140,32 +171,40 @@ def combine(self, other) -> None:
140171 self .total += other .total
141172 self .count += other .count
142173
143- @classmethod
144- def make_aggregator (cls , values : np .ndarray , min_periods : int ) -> BaseAggregator :
145- aggregator = SubtractableAggregator (values , min_periods , cls ())
174+ def make_aggregator (self , values : np .ndarray , min_periods : int ) -> BaseAggregator :
175+ aggregator = SubtractableAggregator (values , min_periods , self )
146176 return aggregator
147177
148178
179+ sum_spec = (("count" , numba .uint64 ), ("total" , numba .float64 ))
180+
181+
182+ @numba .jitclass (sum_spec )
149183class Mean (Sum ):
150184 def finalize (self ) -> Optional [float ]:
151185 if not self .count :
152186 return None
153187 return self .total / self .count
154188
155189
156- def rolling_aggregation (
190+ agg_type .define (Mean .class_type .instance_type ) # type: ignore
191+
192+
193+ aggregation_signature = (numba .float64 [:], numba .int64 [:], numba .int64 [:], numba .int64 )
194+
195+
196+ @numba .njit (aggregation_signature , nogil = True , parallel = True )
197+ def rolling_mean (
157198 values : np .ndarray ,
158199 begin : np .ndarray ,
159200 end : np .ndarray ,
160201 minimum_periods : int ,
161- kernel_class ,
202+ # kernel_class, Don't think I can define this in the signature in nopython mode
162203) -> np .ndarray :
163204 """Perform a generic rolling aggregation"""
164- aggregator = kernel_class .make_aggregator (values , minimum_periods )
205+ aggregator = Mean ().make_aggregator (values , minimum_periods )
206+ # aggregator = kernel_class().make_aggregator(values, minimum_periods)
165207 result = np .empty (len (begin ))
166208 for i , (start , stop ) in enumerate (zip (begin , end )):
167209 result [i ] = aggregator .query (start , stop )
168210 return result
169-
170-
171- rolling_mean = partial (rolling_aggregation , kernel_class = Mean )
0 commit comments