7272 PeriodArray ,
7373 TimedeltaArray ,
7474)
75- from pandas .core .arrays .boolean import BooleanDtype
76- from pandas .core .arrays .floating import FloatingDtype
77- from pandas .core .arrays .integer import IntegerDtype
7875from pandas .core .arrays .masked import (
7976 BaseMaskedArray ,
8077 BaseMaskedDtype ,
@@ -147,26 +144,6 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
147144 },
148145 }
149146
150- # "group_any" and "group_all" are also support masks, but don't go
151- # through WrappedCythonOp
152- _MASKED_CYTHON_FUNCTIONS = {
153- "cummin" ,
154- "cummax" ,
155- "min" ,
156- "max" ,
157- "last" ,
158- "first" ,
159- "rank" ,
160- "sum" ,
161- "ohlc" ,
162- "cumprod" ,
163- "cumsum" ,
164- "prod" ,
165- "mean" ,
166- "var" ,
167- "median" ,
168- }
169-
170147 _cython_arity = {"ohlc" : 4 } # OHLC
171148
172149 # Note: we make this a classmethod and pass kind+how so that caching
@@ -220,8 +197,8 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
220197 """
221198 how = self .how
222199
223- if how in [ "median" ] :
224- # these two only have float64 implementations
200+ if how == "median" :
201+ # median only has a float64 implementation
225202 # We should only get here with is_numeric, as non-numeric cases
226203 # should raise in _get_cython_function
227204 values = ensure_float64 (values )
@@ -293,7 +270,7 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
293270
294271 out_shape : Shape
295272 if how == "ohlc" :
296- out_shape = (ngroups , 4 )
273+ out_shape = (ngroups , arity )
297274 elif arity > 1 :
298275 raise NotImplementedError (
299276 "arity of more than 1 is not supported for the 'how' argument"
@@ -342,9 +319,6 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
342319 return np .dtype (np .float64 )
343320 return dtype
344321
345- def uses_mask (self ) -> bool :
346- return self .how in self ._MASKED_CYTHON_FUNCTIONS
347-
348322 @final
349323 def _ea_wrap_cython_operation (
350324 self ,
@@ -358,7 +332,7 @@ def _ea_wrap_cython_operation(
358332 If we have an ExtensionArray, unwrap, call _cython_operation, and
359333 re-wrap if appropriate.
360334 """
361- if isinstance (values , BaseMaskedArray ) and self . uses_mask () :
335+ if isinstance (values , BaseMaskedArray ):
362336 return self ._masked_ea_wrap_cython_operation (
363337 values ,
364338 min_count = min_count ,
@@ -367,7 +341,7 @@ def _ea_wrap_cython_operation(
367341 ** kwargs ,
368342 )
369343
370- elif isinstance (values , Categorical ) and self . uses_mask () :
344+ elif isinstance (values , Categorical ):
371345 assert self .how == "rank" # the only one implemented ATM
372346 assert values .ordered # checked earlier
373347 mask = values .isna ()
@@ -398,7 +372,7 @@ def _ea_wrap_cython_operation(
398372 )
399373
400374 if self .how in self .cast_blocklist :
401- # i.e. how in ["rank"], since other cast_blocklist methods dont go
375+ # i.e. how in ["rank"], since other cast_blocklist methods don't go
402376 # through cython_operation
403377 return res_values
404378
@@ -411,12 +385,6 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
411385 # All of the functions implemented here are ordinal, so we can
412386 # operate on the tz-naive equivalents
413387 npvalues = values ._ndarray .view ("M8[ns]" )
414- elif isinstance (values .dtype , (BooleanDtype , IntegerDtype )):
415- # IntegerArray or BooleanArray
416- npvalues = values .to_numpy ("float64" , na_value = np .nan )
417- elif isinstance (values .dtype , FloatingDtype ):
418- # FloatingArray
419- npvalues = values .to_numpy (values .dtype .numpy_dtype , na_value = np .nan )
420388 elif isinstance (values .dtype , StringDtype ):
421389 # StringArray
422390 npvalues = values .to_numpy (object , na_value = np .nan )
@@ -440,12 +408,6 @@ def _reconstruct_ea_result(
440408 string_array_cls = dtype .construct_array_type ()
441409 return string_array_cls ._from_sequence (res_values , dtype = dtype )
442410
443- elif isinstance (values .dtype , BaseMaskedDtype ):
444- new_dtype = self ._get_result_dtype (values .dtype .numpy_dtype )
445- dtype = BaseMaskedDtype .from_numpy_dtype (new_dtype )
446- masked_array_cls = dtype .construct_array_type ()
447- return masked_array_cls ._from_sequence (res_values , dtype = dtype )
448-
449411 elif isinstance (values , (DatetimeArray , TimedeltaArray , PeriodArray )):
450412 # In to_cython_values we took a view as M8[ns]
451413 assert res_values .dtype == "M8[ns]"
@@ -489,7 +451,8 @@ def _masked_ea_wrap_cython_operation(
489451 )
490452
491453 if self .how == "ohlc" :
492- result_mask = np .tile (result_mask , (4 , 1 )).T
454+ arity = self ._cython_arity .get (self .how , 1 )
455+ result_mask = np .tile (result_mask , (arity , 1 )).T
493456
494457 # res_values should already have the correct dtype, we just need to
495458 # wrap in a MaskedArray
@@ -580,7 +543,7 @@ def _call_cython_op(
580543 result = maybe_fill (np .empty (out_shape , dtype = out_dtype ))
581544 if self .kind == "aggregate" :
582545 counts = np .zeros (ngroups , dtype = np .int64 )
583- if self .how in ["min" , "max" , "mean" , "last" , "first" ]:
546+ if self .how in ["min" , "max" , "mean" , "last" , "first" , "sum" ]:
584547 func (
585548 out = result ,
586549 counts = counts ,
@@ -591,18 +554,6 @@ def _call_cython_op(
591554 result_mask = result_mask ,
592555 is_datetimelike = is_datetimelike ,
593556 )
594- elif self .how in ["sum" ]:
595- # We support datetimelike
596- func (
597- out = result ,
598- counts = counts ,
599- values = values ,
600- labels = comp_ids ,
601- mask = mask ,
602- result_mask = result_mask ,
603- min_count = min_count ,
604- is_datetimelike = is_datetimelike ,
605- )
606557 elif self .how in ["var" , "ohlc" , "prod" , "median" ]:
607558 func (
608559 result ,
@@ -615,31 +566,21 @@ def _call_cython_op(
615566 ** kwargs ,
616567 )
617568 else :
618- func ( result , counts , values , comp_ids , min_count )
569+ raise NotImplementedError ( f" { self . how } is not implemented" )
619570 else :
620571 # TODO: min_count
621- if self .uses_mask ():
622- if self .how != "rank" :
623- # TODO: should rank take result_mask?
624- kwargs ["result_mask" ] = result_mask
625- func (
626- out = result ,
627- values = values ,
628- labels = comp_ids ,
629- ngroups = ngroups ,
630- is_datetimelike = is_datetimelike ,
631- mask = mask ,
632- ** kwargs ,
633- )
634- else :
635- func (
636- out = result ,
637- values = values ,
638- labels = comp_ids ,
639- ngroups = ngroups ,
640- is_datetimelike = is_datetimelike ,
641- ** kwargs ,
642- )
572+ if self .how != "rank" :
573+ # TODO: should rank take result_mask?
574+ kwargs ["result_mask" ] = result_mask
575+ func (
576+ out = result ,
577+ values = values ,
578+ labels = comp_ids ,
579+ ngroups = ngroups ,
580+ is_datetimelike = is_datetimelike ,
581+ mask = mask ,
582+ ** kwargs ,
583+ )
643584
644585 if self .kind == "aggregate" :
645586 # i.e. counts is defined. Locations where count<min_count
@@ -650,7 +591,7 @@ def _call_cython_op(
650591 cutoff = max (0 if self .how in ["sum" , "prod" ] else 1 , min_count )
651592 empty_groups = counts < cutoff
652593 if empty_groups .any ():
653- if result_mask is not None and self . uses_mask () :
594+ if result_mask is not None :
654595 assert result_mask [empty_groups ].all ()
655596 else :
656597 # Note: this conversion could be lossy, see GH#40767
0 commit comments