Skip to content

Commit b2e09a6

Browse files
committed
style: make AggregateFnV2 generic over accumulator/result (Generic[AggType, U])
Signed-off-by: Arthur <atte.book@gmail.com>
1 parent 6351471 commit b2e09a6

File tree

1 file changed

+33 −33 lines changed

python/ray/data/aggregate.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
import math
3-
from typing import TYPE_CHECKING, Any, Callable, List, Optional
3+
from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional
44

55
import numpy as np
66
import pyarrow.compute as pc
@@ -113,7 +113,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
113113

114114

115115
@PublicAPI(stability="alpha")
116-
class AggregateFnV2(AggregateFn, abc.ABC):
116+
class AggregateFnV2(AggregateFn, abc.ABC, Generic[AggType, U]):
117117
"""Provides an interface to implement efficient aggregations to be applied
118118
to the dataset.
119119
@@ -254,7 +254,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
254254

255255

256256
@PublicAPI
257-
class Count(AggregateFnV2):
257+
class Count(AggregateFnV2[int, int]):
258258
"""Defines count aggregation.
259259
260260
Example:
@@ -303,7 +303,7 @@ def __init__(
303303
zero_factory=lambda: 0,
304304
)
305305

306-
def aggregate_block(self, block: Block) -> AggType:
306+
def aggregate_block(self, block: Block) -> int:
307307
block_accessor = BlockAccessor.for_block(block)
308308

309309
if self._target_col_name is None:
@@ -314,12 +314,12 @@ def aggregate_block(self, block: Block) -> AggType:
314314
self._target_col_name, ignore_nulls=self._ignore_nulls
315315
)
316316

317-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
317+
def combine(self, current_accumulator: int, new: int) -> int:
318318
return current_accumulator + new
319319

320320

321321
@PublicAPI
322-
class Sum(AggregateFnV2):
322+
class Sum(AggregateFnV2[int, int]):
323323
"""Defines sum aggregation.
324324
325325
Example:
@@ -359,17 +359,17 @@ def __init__(
359359
zero_factory=lambda: 0,
360360
)
361361

362-
def aggregate_block(self, block: Block) -> AggType:
362+
def aggregate_block(self, block: Block) -> int:
363363
return BlockAccessor.for_block(block).sum(
364364
self._target_col_name, self._ignore_nulls
365365
)
366366

367-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
367+
def combine(self, current_accumulator: int, new: int) -> int:
368368
return current_accumulator + new
369369

370370

371371
@PublicAPI
372-
class Min(AggregateFnV2):
372+
class Min(AggregateFnV2[int, int]):
373373
"""Defines min aggregation.
374374
375375
Example:
@@ -412,17 +412,17 @@ def __init__(
412412
zero_factory=lambda: float("+inf"),
413413
)
414414

415-
def aggregate_block(self, block: Block) -> AggType:
415+
def aggregate_block(self, block: Block) -> int:
416416
return BlockAccessor.for_block(block).min(
417417
self._target_col_name, self._ignore_nulls
418418
)
419419

420-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
420+
def combine(self, current_accumulator: int, new: int) -> int:
421421
return min(current_accumulator, new)
422422

423423

424424
@PublicAPI
425-
class Max(AggregateFnV2):
425+
class Max(AggregateFnV2[int, int]):
426426
"""Defines max aggregation.
427427
428428
Example:
@@ -466,17 +466,17 @@ def __init__(
466466
zero_factory=lambda: float("-inf"),
467467
)
468468

469-
def aggregate_block(self, block: Block) -> AggType:
469+
def aggregate_block(self, block: Block) -> int:
470470
return BlockAccessor.for_block(block).max(
471471
self._target_col_name, self._ignore_nulls
472472
)
473473

474-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
474+
def combine(self, current_accumulator: int, new: int) -> int:
475475
return max(current_accumulator, new)
476476

477477

478478
@PublicAPI
479-
class Mean(AggregateFnV2):
479+
class Mean(AggregateFnV2[List[int], float]):
480480
"""Defines mean (average) aggregation.
481481
482482
Example:
@@ -521,7 +521,7 @@ def __init__(
521521
zero_factory=lambda: list([0, 0]), # noqa: C410
522522
)
523523

524-
def aggregate_block(self, block: Block) -> AggType:
524+
def aggregate_block(self, block: Block) -> List[int]:
525525
block_acc = BlockAccessor.for_block(block)
526526
count = block_acc.count(self._target_col_name, self._ignore_nulls)
527527

@@ -539,10 +539,10 @@ def aggregate_block(self, block: Block) -> AggType:
539539

540540
return [sum_, count]
541541

542-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
542+
def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
543543
return [current_accumulator[0] + new[0], current_accumulator[1] + new[1]]
544544

545-
def finalize(self, accumulator: AggType) -> Optional[U]:
545+
def finalize(self, accumulator: List[int]) -> Optional[float]:
546546
# The final accumulator for a group is [total_sum, total_count].
547547
if accumulator[1] == 0:
548548
# If total_count is 0 (e.g., group was empty or all nulls ignored),
@@ -553,7 +553,7 @@ def finalize(self, accumulator: AggType) -> Optional[U]:
553553

554554

555555
@PublicAPI
556-
class Std(AggregateFnV2):
556+
class Std(AggregateFnV2[List[float], float]):
557557
"""Defines standard deviation aggregation.
558558
559559
Uses Welford's online algorithm for numerical stability. This method computes
@@ -610,7 +610,7 @@ def __init__(
610610

611611
self._ddof = ddof
612612

613-
def aggregate_block(self, block: Block) -> AggType:
613+
def aggregate_block(self, block: Block) -> List[float]:
614614
block_acc = BlockAccessor.for_block(block)
615615
count = block_acc.count(self._target_col_name, ignore_nulls=self._ignore_nulls)
616616
if count == 0 or count is None:
@@ -627,7 +627,7 @@ def aggregate_block(self, block: Block) -> AggType:
627627
)
628628
return [M2, mean, count]
629629

630-
def combine(self, current_accumulator: List[float], new: List[float]) -> AggType:
630+
def combine(self, current_accumulator: List[float], new: List[float]) -> List[float]:
631631
# Merges two accumulators [M2, mean, count] using a parallel algorithm.
632632
# See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
633633
M2_a, mean_a, count_a = current_accumulator
@@ -643,7 +643,7 @@ def combine(self, current_accumulator: List[float], new: List[float]) -> AggType
643643
M2 = M2_a + M2_b + (delta**2) * count_a * count_b / count
644644
return [M2, mean, count]
645645

646-
def finalize(self, accumulator: List[float]) -> Optional[U]:
646+
def finalize(self, accumulator: List[float]) -> Optional[float]:
647647
# Compute the final standard deviation from the accumulated
648648
# sum of squared differences from current mean and the count.
649649
# Final accumulator: [M2, mean, count]
@@ -658,7 +658,7 @@ def finalize(self, accumulator: List[float]) -> Optional[U]:
658658

659659

660660
@PublicAPI
661-
class AbsMax(AggregateFnV2):
661+
class AbsMax(AggregateFnV2[int, int]):
662662
"""Defines absolute max aggregation.
663663
664664
Example:
@@ -701,7 +701,7 @@ def __init__(
701701
zero_factory=lambda: 0,
702702
)
703703

704-
def aggregate_block(self, block: Block) -> AggType:
704+
def aggregate_block(self, block: Block) -> int:
705705
block_accessor = BlockAccessor.for_block(block)
706706

707707
max_ = block_accessor.max(self._target_col_name, self._ignore_nulls)
@@ -715,12 +715,12 @@ def aggregate_block(self, block: Block) -> AggType:
715715
abs(min_),
716716
)
717717

718-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
718+
def combine(self, current_accumulator: int, new: int) -> int:
719719
return max(current_accumulator, new)
720720

721721

722722
@PublicAPI
723-
class Quantile(AggregateFnV2):
723+
class Quantile(AggregateFnV2[List[Any], List[Any]]):
724724
"""Defines Quantile aggregation.
725725
726726
Example:
@@ -790,7 +790,7 @@ def combine(self, current_accumulator: List[Any], new: List[Any]) -> List[Any]:
790790

791791
return ls
792792

793-
def aggregate_block(self, block: Block) -> AggType:
793+
def aggregate_block(self, block: Block) -> List[Any]:
794794
block_acc = BlockAccessor.for_block(block)
795795
ls = []
796796

@@ -799,7 +799,7 @@ def aggregate_block(self, block: Block) -> AggType:
799799

800800
return ls
801801

802-
def finalize(self, accumulator: List[Any]) -> Optional[U]:
802+
def finalize(self, accumulator: List[Any]) -> Optional[Any]:
803803
if self._ignore_nulls:
804804
accumulator = [v for v in accumulator if not is_null(v)]
805805
else:
@@ -831,7 +831,7 @@ def finalize(self, accumulator: List[Any]) -> Optional[U]:
831831

832832

833833
@PublicAPI
834-
class Unique(AggregateFnV2):
834+
class Unique(AggregateFnV2[set, set]):
835835
"""Defines unique aggregation.
836836
837837
Example:
@@ -870,10 +870,10 @@ def __init__(
870870
zero_factory=set,
871871
)
872872

873-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
873+
def combine(self, current_accumulator: set, new: set) -> set:
874874
return self._to_set(current_accumulator) | self._to_set(new)
875875

876-
def aggregate_block(self, block: Block) -> AggType:
876+
def aggregate_block(self, block: Block) -> set:
877877
import pyarrow.compute as pac
878878

879879
col = BlockAccessor.for_block(block).to_arrow().column(self._target_col_name)
@@ -1013,7 +1013,7 @@ def _safe_combine(
10131013

10141014

10151015
@PublicAPI(stability="alpha")
1016-
class MissingValuePercentage(AggregateFnV2):
1016+
class MissingValuePercentage(AggregateFnV2[List[int], float]):
10171017
"""Calculates the percentage of null values in a column.
10181018
10191019
This aggregation computes the percentage of null (missing) values in a dataset column.
@@ -1094,7 +1094,7 @@ def finalize(self, accumulator: List[int]) -> Optional[float]:
10941094

10951095

10961096
@PublicAPI(stability="alpha")
1097-
class ZeroPercentage(AggregateFnV2):
1097+
class ZeroPercentage(AggregateFnV2[List[int], float]):
10981098
"""Calculates the percentage of zero values in a numeric column.
10991099
11001100
This aggregation computes the percentage of zero values in a numeric dataset column.

0 commit comments

Comments
 (0)