11import abc
22import math
3- from typing import TYPE_CHECKING , Any , Callable , List , Optional
3+ from typing import TYPE_CHECKING , Any , Callable , Generic , List , Optional
44
55import numpy as np
66import pyarrow .compute as pc
@@ -113,7 +113,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
113113
114114
115115@PublicAPI (stability = "alpha" )
116- class AggregateFnV2 (AggregateFn , abc .ABC ):
116+ class AggregateFnV2 (AggregateFn , abc .ABC , Generic [ AggType , U ] ):
117117 """Provides an interface to implement efficient aggregations to be applied
118118 to the dataset.
119119
@@ -254,7 +254,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
254254
255255
256256@PublicAPI
257- class Count (AggregateFnV2 ):
257+ class Count (AggregateFnV2 [ int , int ] ):
258258 """Defines count aggregation.
259259
260260 Example:
@@ -303,7 +303,7 @@ def __init__(
303303 zero_factory = lambda : 0 ,
304304 )
305305
306- def aggregate_block (self , block : Block ) -> AggType :
306+ def aggregate_block (self , block : Block ) -> int :
307307 block_accessor = BlockAccessor .for_block (block )
308308
309309 if self ._target_col_name is None :
@@ -314,12 +314,12 @@ def aggregate_block(self, block: Block) -> AggType:
314314 self ._target_col_name , ignore_nulls = self ._ignore_nulls
315315 )
316316
317- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
317+ def combine (self , current_accumulator : int , new : int ) -> int :
318318 return current_accumulator + new
319319
320320
321321@PublicAPI
322- class Sum (AggregateFnV2 ):
322+ class Sum (AggregateFnV2 [ int , int ] ):
323323 """Defines sum aggregation.
324324
325325 Example:
@@ -359,17 +359,17 @@ def __init__(
359359 zero_factory = lambda : 0 ,
360360 )
361361
362- def aggregate_block (self , block : Block ) -> AggType :
362+ def aggregate_block (self , block : Block ) -> int :
363363 return BlockAccessor .for_block (block ).sum (
364364 self ._target_col_name , self ._ignore_nulls
365365 )
366366
367- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
367+ def combine (self , current_accumulator : int , new : int ) -> int :
368368 return current_accumulator + new
369369
370370
371371@PublicAPI
372- class Min (AggregateFnV2 ):
372+ class Min (AggregateFnV2 [ int , int ] ):
373373 """Defines min aggregation.
374374
375375 Example:
@@ -412,17 +412,17 @@ def __init__(
412412 zero_factory = lambda : float ("+inf" ),
413413 )
414414
415- def aggregate_block (self , block : Block ) -> AggType :
415+ def aggregate_block (self , block : Block ) -> int :
416416 return BlockAccessor .for_block (block ).min (
417417 self ._target_col_name , self ._ignore_nulls
418418 )
419419
420- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
420+ def combine (self , current_accumulator : int , new : int ) -> int :
421421 return min (current_accumulator , new )
422422
423423
424424@PublicAPI
425- class Max (AggregateFnV2 ):
425+ class Max (AggregateFnV2 [ int , int ] ):
426426 """Defines max aggregation.
427427
428428 Example:
@@ -466,17 +466,17 @@ def __init__(
466466 zero_factory = lambda : float ("-inf" ),
467467 )
468468
469- def aggregate_block (self , block : Block ) -> AggType :
469+ def aggregate_block (self , block : Block ) -> int :
470470 return BlockAccessor .for_block (block ).max (
471471 self ._target_col_name , self ._ignore_nulls
472472 )
473473
474- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
474+ def combine (self , current_accumulator : int , new : int ) -> int :
475475 return max (current_accumulator , new )
476476
477477
478478@PublicAPI
479- class Mean (AggregateFnV2 ):
479+ class Mean (AggregateFnV2 [ List [ int ], float ] ):
480480 """Defines mean (average) aggregation.
481481
482482 Example:
@@ -521,7 +521,7 @@ def __init__(
521521 zero_factory = lambda : list ([0 , 0 ]), # noqa: C410
522522 )
523523
524- def aggregate_block (self , block : Block ) -> AggType :
524+ def aggregate_block (self , block : Block ) -> List [ int ] :
525525 block_acc = BlockAccessor .for_block (block )
526526 count = block_acc .count (self ._target_col_name , self ._ignore_nulls )
527527
@@ -539,10 +539,10 @@ def aggregate_block(self, block: Block) -> AggType:
539539
540540 return [sum_ , count ]
541541
542- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
542+ def combine (self , current_accumulator : List [ int ] , new : List [ int ] ) -> List [ int ] :
543543 return [current_accumulator [0 ] + new [0 ], current_accumulator [1 ] + new [1 ]]
544544
545- def finalize (self , accumulator : AggType ) -> Optional [U ]:
545+ def finalize (self , accumulator : List [ int ] ) -> Optional [float ]:
546546 # The final accumulator for a group is [total_sum, total_count].
547547 if accumulator [1 ] == 0 :
548548 # If total_count is 0 (e.g., group was empty or all nulls ignored),
@@ -553,7 +553,7 @@ def finalize(self, accumulator: AggType) -> Optional[U]:
553553
554554
555555@PublicAPI
556- class Std (AggregateFnV2 ):
556+ class Std (AggregateFnV2 [ List [ float ], float ] ):
557557 """Defines standard deviation aggregation.
558558
559559 Uses Welford's online algorithm for numerical stability. This method computes
@@ -610,7 +610,7 @@ def __init__(
610610
611611 self ._ddof = ddof
612612
613- def aggregate_block (self , block : Block ) -> AggType :
613+ def aggregate_block (self , block : Block ) -> List [ float ] :
614614 block_acc = BlockAccessor .for_block (block )
615615 count = block_acc .count (self ._target_col_name , ignore_nulls = self ._ignore_nulls )
616616 if count == 0 or count is None :
@@ -627,7 +627,7 @@ def aggregate_block(self, block: Block) -> AggType:
627627 )
628628 return [M2 , mean , count ]
629629
630- def combine (self , current_accumulator : List [float ], new : List [float ]) -> AggType :
630+ def combine (self , current_accumulator : List [float ], new : List [float ]) -> List [ float ] :
631631 # Merges two accumulators [M2, mean, count] using a parallel algorithm.
632632 # See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
633633 M2_a , mean_a , count_a = current_accumulator
@@ -643,7 +643,7 @@ def combine(self, current_accumulator: List[float], new: List[float]) -> AggType
643643 M2 = M2_a + M2_b + (delta ** 2 ) * count_a * count_b / count
644644 return [M2 , mean , count ]
645645
646- def finalize (self , accumulator : List [float ]) -> Optional [U ]:
646+ def finalize (self , accumulator : List [float ]) -> Optional [float ]:
647647 # Compute the final standard deviation from the accumulated
648648 # sum of squared differences from current mean and the count.
649649 # Final accumulator: [M2, mean, count]
@@ -658,7 +658,7 @@ def finalize(self, accumulator: List[float]) -> Optional[U]:
658658
659659
660660@PublicAPI
661- class AbsMax (AggregateFnV2 ):
661+ class AbsMax (AggregateFnV2 [ int , int ] ):
662662 """Defines absolute max aggregation.
663663
664664 Example:
@@ -701,7 +701,7 @@ def __init__(
701701 zero_factory = lambda : 0 ,
702702 )
703703
704- def aggregate_block (self , block : Block ) -> AggType :
704+ def aggregate_block (self , block : Block ) -> int :
705705 block_accessor = BlockAccessor .for_block (block )
706706
707707 max_ = block_accessor .max (self ._target_col_name , self ._ignore_nulls )
@@ -715,12 +715,12 @@ def aggregate_block(self, block: Block) -> AggType:
715715 abs (min_ ),
716716 )
717717
718- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
718+ def combine (self , current_accumulator : int , new : int ) -> int :
719719 return max (current_accumulator , new )
720720
721721
722722@PublicAPI
723- class Quantile (AggregateFnV2 ):
723+ class Quantile (AggregateFnV2 [ List [ Any ], List [ Any ]] ):
724724 """Defines Quantile aggregation.
725725
726726 Example:
@@ -790,7 +790,7 @@ def combine(self, current_accumulator: List[Any], new: List[Any]) -> List[Any]:
790790
791791 return ls
792792
793- def aggregate_block (self , block : Block ) -> AggType :
793+ def aggregate_block (self , block : Block ) -> List [ Any ] :
794794 block_acc = BlockAccessor .for_block (block )
795795 ls = []
796796
@@ -799,7 +799,7 @@ def aggregate_block(self, block: Block) -> AggType:
799799
800800 return ls
801801
802- def finalize (self , accumulator : List [Any ]) -> Optional [U ]:
802+ def finalize (self , accumulator : List [Any ]) -> Optional [Any ]:
803803 if self ._ignore_nulls :
804804 accumulator = [v for v in accumulator if not is_null (v )]
805805 else :
@@ -831,7 +831,7 @@ def finalize(self, accumulator: List[Any]) -> Optional[U]:
831831
832832
833833@PublicAPI
834- class Unique (AggregateFnV2 ):
834+ class Unique (AggregateFnV2 [ set , set ] ):
835835 """Defines unique aggregation.
836836
837837 Example:
@@ -870,10 +870,10 @@ def __init__(
870870 zero_factory = set ,
871871 )
872872
873- def combine (self , current_accumulator : AggType , new : AggType ) -> AggType :
873+ def combine (self , current_accumulator : set , new : set ) -> set :
874874 return self ._to_set (current_accumulator ) | self ._to_set (new )
875875
876- def aggregate_block (self , block : Block ) -> AggType :
876+ def aggregate_block (self , block : Block ) -> set :
877877 import pyarrow .compute as pac
878878
879879 col = BlockAccessor .for_block (block ).to_arrow ().column (self ._target_col_name )
@@ -1013,7 +1013,7 @@ def _safe_combine(
10131013
10141014
10151015@PublicAPI (stability = "alpha" )
1016- class MissingValuePercentage (AggregateFnV2 ):
1016+ class MissingValuePercentage (AggregateFnV2 [ List [ int ], float ] ):
10171017 """Calculates the percentage of null values in a column.
10181018
10191019 This aggregation computes the percentage of null (missing) values in a dataset column.
@@ -1094,7 +1094,7 @@ def finalize(self, accumulator: List[int]) -> Optional[float]:
10941094
10951095
10961096@PublicAPI (stability = "alpha" )
1097- class ZeroPercentage (AggregateFnV2 ):
1097+ class ZeroPercentage (AggregateFnV2 [ List [ int ], float ] ):
10981098 """Calculates the percentage of zero values in a numeric column.
10991099
11001100 This aggregation computes the percentage of zero values in a numeric dataset column.
0 commit comments