11import abc
22import math
3- from typing import TYPE_CHECKING , Any , Callable , List , Optional
3+ from typing import TYPE_CHECKING , Any , Callable , Generic , List , Optional , TypeVar , Union
44
55import numpy as np
66import pyarrow .compute as pc
77
88from ray .data ._internal .util import is_null
99from ray .data .block import (
10- AggType ,
1110 Block ,
1211 BlockAccessor ,
1312 BlockColumnAccessor ,
1413 KeyType ,
15- T ,
16- U ,
1714)
1815from ray .util .annotations import Deprecated , PublicAPI
1916
2017if TYPE_CHECKING :
2118 from ray .data .dataset import Schema
2219
20+ T = TypeVar ("T" )
21+ U = TypeVar ("U" )
22+ AggType = TypeVar ("AggType" )
23+
2324
2425@Deprecated (message = "AggregateFn is deprecated, please use AggregateFnV2" )
2526@PublicAPI
@@ -113,7 +114,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
113114
114115
115116@PublicAPI (stability = "alpha" )
116- class AggregateFnV2 (AggregateFn , abc .ABC ):
117+ class AggregateFnV2 (AggregateFn , abc .ABC , Generic [ AggType , U ] ):
117118 """Provides an interface to implement efficient aggregations to be applied
118119 to the dataset.
119120
@@ -254,7 +255,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
254255
255256
256257@PublicAPI
257- class Count (AggregateFnV2 ):
258+ class Count (AggregateFnV2 [ int , int ] ):
258259 """Defines count aggregation.
259260
260261 Example:
@@ -303,7 +304,7 @@ def __init__(
303304 zero_factory = lambda : 0 ,
304305 )
305306
306- def aggregate_block (self , block : Block ) -> AggType :
307+ def aggregate_block (self , block : Block ) -> int :
307308 block_accessor = BlockAccessor .for_block (block )
308309
309310 if self ._target_col_name is None :
@@ -314,12 +315,12 @@ def aggregate_block(self, block: Block) -> AggType:
314315 self ._target_col_name , ignore_nulls = self ._ignore_nulls
315316 )
316317
317- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
318+ def combine (self , current_accumulator : int , new : int ) -> int :
318319 return current_accumulator + new
319320
320321
321322@PublicAPI
322- class Sum (AggregateFnV2 ):
323+ class Sum (AggregateFnV2 [ Union [ int , float ], Union [ int , float ]] ):
323324 """Defines sum aggregation.
324325
325326 Example:
@@ -359,17 +360,19 @@ def __init__(
359360 zero_factory = lambda : 0 ,
360361 )
361362
362- def aggregate_block (self , block : Block ) -> AggType :
363+ def aggregate_block (self , block : Block ) -> Union [ int , float ] :
363364 return BlockAccessor .for_block (block ).sum (
364365 self ._target_col_name , self ._ignore_nulls
365366 )
366367
367- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
368+ def combine (
369+ self , current_accumulator : Union [int , float ], new : Union [int , float ]
370+ ) -> Union [int , float ]:
368371 return current_accumulator + new
369372
370373
371374@PublicAPI
372- class Min (AggregateFnV2 ):
375+ class Min (AggregateFnV2 [ Union [ int , float ], Union [ int , float ]] ):
373376 """Defines min aggregation.
374377
375378 Example:
@@ -412,17 +415,19 @@ def __init__(
412415 zero_factory = lambda : float ("+inf" ),
413416 )
414417
415- def aggregate_block (self , block : Block ) -> AggType :
418+ def aggregate_block (self , block : Block ) -> Union [ int , float ] :
416419 return BlockAccessor .for_block (block ).min (
417420 self ._target_col_name , self ._ignore_nulls
418421 )
419422
420- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
423+ def combine (
424+ self , current_accumulator : Union [int , float ], new : Union [int , float ]
425+ ) -> Union [int , float ]:
421426 return min (current_accumulator , new )
422427
423428
424429@PublicAPI
425- class Max (AggregateFnV2 ):
430+ class Max (AggregateFnV2 [ Union [ int , float ], Union [ int , float ]] ):
426431 """Defines max aggregation.
427432
428433 Example:
@@ -458,25 +463,26 @@ def __init__(
458463 ignore_nulls : bool = True ,
459464 alias_name : Optional [str ] = None ,
460465 ):
461-
462466 super ().__init__ (
463467 alias_name if alias_name else f"max({ str (on )} )" ,
464468 on = on ,
465469 ignore_nulls = ignore_nulls ,
466470 zero_factory = lambda : float ("-inf" ),
467471 )
468472
469- def aggregate_block (self , block : Block ) -> AggType :
473+ def aggregate_block (self , block : Block ) -> Union [ int , float ] :
470474 return BlockAccessor .for_block (block ).max (
471475 self ._target_col_name , self ._ignore_nulls
472476 )
473477
474- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
478+ def combine (
479+ self , current_accumulator : Union [int , float ], new : Union [int , float ]
480+ ) -> Union [int , float ]:
475481 return max (current_accumulator , new )
476482
477483
478484@PublicAPI
479- class Mean (AggregateFnV2 ):
485+ class Mean (AggregateFnV2 [ List [ int ], float ] ):
480486 """Defines mean (average) aggregation.
481487
482488 Example:
@@ -521,7 +527,7 @@ def __init__(
521527 zero_factory = lambda : list ([0 , 0 ]), # noqa: C410
522528 )
523529
524- def aggregate_block (self , block : Block ) -> AggType :
530+ def aggregate_block (self , block : Block ) -> List [ int ] :
525531 block_acc = BlockAccessor .for_block (block )
526532 count = block_acc .count (self ._target_col_name , self ._ignore_nulls )
527533
@@ -539,10 +545,10 @@ def aggregate_block(self, block: Block) -> AggType:
539545
540546 return [sum_ , count ]
541547
542- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
548+ def combine (self , current_accumulator : List [ int ] , new : List [ int ] ) -> List [ int ] :
543549 return [current_accumulator [0 ] + new [0 ], current_accumulator [1 ] + new [1 ]]
544550
545- def finalize (self , accumulator : AggType ) -> Optional [U ]:
551+ def finalize (self , accumulator : List [ int ] ) -> Optional [float ]:
546552 # The final accumulator for a group is [total_sum, total_count].
547553 if accumulator [1 ] == 0 :
548554 # If total_count is 0 (e.g., group was empty or all nulls ignored),
@@ -553,7 +559,7 @@ def finalize(self, accumulator: AggType) -> Optional[U]:
553559
554560
555561@PublicAPI
556- class Std (AggregateFnV2 ):
562+ class Std (AggregateFnV2 [ List [ float ], float ] ):
557563 """Defines standard deviation aggregation.
558564
559565 Uses Welford's online algorithm for numerical stability. This method computes
@@ -610,7 +616,7 @@ def __init__(
610616
611617 self ._ddof = ddof
612618
613- def aggregate_block (self , block : Block ) -> AggType :
619+ def aggregate_block (self , block : Block ) -> List [ float ] :
614620 block_acc = BlockAccessor .for_block (block )
615621 count = block_acc .count (self ._target_col_name , ignore_nulls = self ._ignore_nulls )
616622 if count == 0 or count is None :
@@ -627,7 +633,9 @@ def aggregate_block(self, block: Block) -> AggType:
627633 )
628634 return [M2 , mean , count ]
629635
630- def combine (self , current_accumulator : List [float ], new : List [float ]) -> AggType :
636+ def combine (
637+ self , current_accumulator : List [float ], new : List [float ]
638+ ) -> List [float ]:
631639 # Merges two accumulators [M2, mean, count] using a parallel algorithm.
632640 # See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
633641 M2_a , mean_a , count_a = current_accumulator
@@ -643,7 +651,7 @@ def combine(self, current_accumulator: List[float], new: List[float]) -> AggType
643651 M2 = M2_a + M2_b + (delta ** 2 ) * count_a * count_b / count
644652 return [M2 , mean , count ]
645653
646- def finalize (self , accumulator : List [float ]) -> Optional [U ]:
654+ def finalize (self , accumulator : List [float ]) -> Optional [float ]:
647655 # Compute the final standard deviation from the accumulated
648656 # sum of squared differences from current mean and the count.
649657 # Final accumulator: [M2, mean, count]
@@ -658,7 +666,7 @@ def finalize(self, accumulator: List[float]) -> Optional[U]:
658666
659667
660668@PublicAPI
661- class AbsMax (AggregateFnV2 ):
669+ class AbsMax (AggregateFnV2 [ Union [ int , float ], Union [ int , float ]] ):
662670 """Defines absolute max aggregation.
663671
664672 Example:
@@ -701,7 +709,7 @@ def __init__(
701709 zero_factory = lambda : 0 ,
702710 )
703711
704- def aggregate_block (self , block : Block ) -> AggType :
712+ def aggregate_block (self , block : Block ) -> Union [ int , float ] :
705713 block_accessor = BlockAccessor .for_block (block )
706714
707715 max_ = block_accessor .max (self ._target_col_name , self ._ignore_nulls )
@@ -715,12 +723,14 @@ def aggregate_block(self, block: Block) -> AggType:
715723 abs (min_ ),
716724 )
717725
718- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
726+ def combine (
727+ self , current_accumulator : Union [int , float ], new : Union [int , float ]
728+ ) -> Union [int , float ]:
719729 return max (current_accumulator , new )
720730
721731
722732@PublicAPI
723- class Quantile (AggregateFnV2 ):
733+ class Quantile (AggregateFnV2 [ List [ Any ], List [ Any ]] ):
724734 """Defines Quantile aggregation.
725735
726736 Example:
@@ -790,7 +800,7 @@ def combine(self, current_accumulator: List[Any], new: List[Any]) -> List[Any]:
790800
791801 return ls
792802
793- def aggregate_block (self , block : Block ) -> AggType :
803+ def aggregate_block (self , block : Block ) -> List [ Any ] :
794804 block_acc = BlockAccessor .for_block (block )
795805 ls = []
796806
@@ -799,7 +809,7 @@ def aggregate_block(self, block: Block) -> AggType:
799809
800810 return ls
801811
802- def finalize (self , accumulator : List [Any ]) -> Optional [U ]:
812+ def finalize (self , accumulator : List [Any ]) -> Optional [Any ]:
803813 if self ._ignore_nulls :
804814 accumulator = [v for v in accumulator if not is_null (v )]
805815 else :
@@ -831,7 +841,7 @@ def finalize(self, accumulator: List[Any]) -> Optional[U]:
831841
832842
833843@PublicAPI
834- class Unique (AggregateFnV2 ):
844+ class Unique (AggregateFnV2 [ set , set ] ):
835845 """Defines unique aggregation.
836846
837847 Example:
@@ -870,10 +880,10 @@ def __init__(
870880 zero_factory = set ,
871881 )
872882
873- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
883+ def combine (self , current_accumulator : set , new : set ) -> set :
874884 return self ._to_set (current_accumulator ) | self ._to_set (new )
875885
876- def aggregate_block (self , block : Block ) -> AggType :
886+ def aggregate_block (self , block : Block ) -> set :
877887 import pyarrow .compute as pac
878888
879889 col = BlockAccessor .for_block (block ).to_arrow ().column (self ._target_col_name )
@@ -988,7 +998,6 @@ def _null_safe_combine(
988998 def _safe_combine (
989999 cur : Optional [AggType ], new : Optional [AggType ]
9901000 ) -> Optional [AggType ]:
991-
9921001 if is_null (cur ):
9931002 return new
9941003 elif is_null (new ):
@@ -1001,7 +1010,6 @@ def _safe_combine(
10011010 def _safe_combine (
10021011 cur : Optional [AggType ], new : Optional [AggType ]
10031012 ) -> Optional [AggType ]:
1004-
10051013 if is_null (new ):
10061014 return new
10071015 elif is_null (cur ):
@@ -1013,7 +1021,7 @@ def _safe_combine(
10131021
10141022
10151023@PublicAPI (stability = "alpha" )
1016- class MissingValuePercentage (AggregateFnV2 ):
1024+ class MissingValuePercentage (AggregateFnV2 [ List [ int ], float ] ):
10171025 """Calculates the percentage of null values in a column.
10181026
10191027 This aggregation computes the percentage of null (missing) values in a dataset column.
@@ -1094,7 +1102,7 @@ def finalize(self, accumulator: List[int]) -> Optional[float]:
10941102
10951103
10961104@PublicAPI (stability = "alpha" )
1097- class ZeroPercentage (AggregateFnV2 ):
1105+ class ZeroPercentage (AggregateFnV2 [ List [ int ], float ] ):
10981106 """Calculates the percentage of zero values in a numeric column.
10991107
11001108 This aggregation computes the percentage of zero values in a numeric dataset column.
0 commit comments