Skip to content

Commit 3a05b2a

Browse files
committed
style: make AggregateFnV2 generic over accumulator/result (Generic[AggType, U])
Signed-off-by: Arthur <atte.book@gmail.com>
1 parent 6351471 commit 3a05b2a

File tree

2 files changed

+52
-43
lines changed

2 files changed

+52
-43
lines changed

doc/source/conf.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,9 @@
3232
# If extensions (or modules to document with autodoc) are in another directory,
3333
# add these directories to sys.path here. If the directory is relative to the
3434
# documentation root, use os.path.abspath to make it absolute, like shown here.
35-
assert not os.path.exists("../../python/ray/_raylet.so"), (
36-
"_raylet.so should not be imported for the purpose for doc build, "
37-
"please rename the file to _raylet.so.bak and try again."
38-
)
35+
assert not os.path.exists(
36+
"../../python/ray/_raylet.so"
37+
), "_raylet.so should not be imported for the purpose for doc build, please rename the file to _raylet.so.bak and try again."
3938
sys.path.insert(0, os.path.abspath("../../python/"))
4039

4140
# -- General configuration ------------------------------------------------
@@ -130,6 +129,8 @@
130129
nitpicky = True
131130
nitpick_ignore_regex = [
132131
("py:obj", "ray.actor.T"),
132+
("py:obj", "ray.data.aggregate.U"),
133+
("py:obj", "ray.data.aggregate.AggType"),
133134
("py:class", ".*"),
134135
# Workaround for https://github.com/sphinx-doc/sphinx/issues/10974
135136
("py:obj", "ray\\.data\\.datasource\\.datasink\\.WriteReturnType"),

python/ray/data/aggregate.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
import abc
22
import math
3-
from typing import TYPE_CHECKING, Any, Callable, List, Optional
3+
from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional, TypeVar, Union
44

55
import numpy as np
66
import pyarrow.compute as pc
77

88
from ray.data._internal.util import is_null
99
from ray.data.block import (
10-
AggType,
1110
Block,
1211
BlockAccessor,
1312
BlockColumnAccessor,
1413
KeyType,
15-
T,
16-
U,
1714
)
1815
from ray.util.annotations import Deprecated, PublicAPI
1916

2017
if TYPE_CHECKING:
2118
from ray.data.dataset import Schema
2219

20+
T = TypeVar("T")
21+
U = TypeVar("U")
22+
AggType = TypeVar("AggType")
23+
2324

2425
@Deprecated(message="AggregateFn is deprecated, please use AggregateFnV2")
2526
@PublicAPI
@@ -113,7 +114,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
113114

114115

115116
@PublicAPI(stability="alpha")
116-
class AggregateFnV2(AggregateFn, abc.ABC):
117+
class AggregateFnV2(AggregateFn, abc.ABC, Generic[AggType, U]):
117118
"""Provides an interface to implement efficient aggregations to be applied
118119
to the dataset.
119120
@@ -254,7 +255,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
254255

255256

256257
@PublicAPI
257-
class Count(AggregateFnV2):
258+
class Count(AggregateFnV2[int, int]):
258259
"""Defines count aggregation.
259260
260261
Example:
@@ -303,7 +304,7 @@ def __init__(
303304
zero_factory=lambda: 0,
304305
)
305306

306-
def aggregate_block(self, block: Block) -> AggType:
307+
def aggregate_block(self, block: Block) -> int:
307308
block_accessor = BlockAccessor.for_block(block)
308309

309310
if self._target_col_name is None:
@@ -314,12 +315,12 @@ def aggregate_block(self, block: Block) -> AggType:
314315
self._target_col_name, ignore_nulls=self._ignore_nulls
315316
)
316317

317-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
318+
def combine(self, current_accumulator: int, new: int) -> int:
318319
return current_accumulator + new
319320

320321

321322
@PublicAPI
322-
class Sum(AggregateFnV2):
323+
class Sum(AggregateFnV2[Union[int, float], Union[int, float]]):
323324
"""Defines sum aggregation.
324325
325326
Example:
@@ -359,17 +360,19 @@ def __init__(
359360
zero_factory=lambda: 0,
360361
)
361362

362-
def aggregate_block(self, block: Block) -> AggType:
363+
def aggregate_block(self, block: Block) -> Union[int, float]:
363364
return BlockAccessor.for_block(block).sum(
364365
self._target_col_name, self._ignore_nulls
365366
)
366367

367-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
368+
def combine(
369+
self, current_accumulator: Union[int, float], new: Union[int, float]
370+
) -> Union[int, float]:
368371
return current_accumulator + new
369372

370373

371374
@PublicAPI
372-
class Min(AggregateFnV2):
375+
class Min(AggregateFnV2[Union[int, float], Union[int, float]]):
373376
"""Defines min aggregation.
374377
375378
Example:
@@ -412,17 +415,19 @@ def __init__(
412415
zero_factory=lambda: float("+inf"),
413416
)
414417

415-
def aggregate_block(self, block: Block) -> AggType:
418+
def aggregate_block(self, block: Block) -> Union[int, float]:
416419
return BlockAccessor.for_block(block).min(
417420
self._target_col_name, self._ignore_nulls
418421
)
419422

420-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
423+
def combine(
424+
self, current_accumulator: Union[int, float], new: Union[int, float]
425+
) -> Union[int, float]:
421426
return min(current_accumulator, new)
422427

423428

424429
@PublicAPI
425-
class Max(AggregateFnV2):
430+
class Max(AggregateFnV2[Union[int, float], Union[int, float]]):
426431
"""Defines max aggregation.
427432
428433
Example:
@@ -458,25 +463,26 @@ def __init__(
458463
ignore_nulls: bool = True,
459464
alias_name: Optional[str] = None,
460465
):
461-
462466
super().__init__(
463467
alias_name if alias_name else f"max({str(on)})",
464468
on=on,
465469
ignore_nulls=ignore_nulls,
466470
zero_factory=lambda: float("-inf"),
467471
)
468472

469-
def aggregate_block(self, block: Block) -> AggType:
473+
def aggregate_block(self, block: Block) -> Union[int, float]:
470474
return BlockAccessor.for_block(block).max(
471475
self._target_col_name, self._ignore_nulls
472476
)
473477

474-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
478+
def combine(
479+
self, current_accumulator: Union[int, float], new: Union[int, float]
480+
) -> Union[int, float]:
475481
return max(current_accumulator, new)
476482

477483

478484
@PublicAPI
479-
class Mean(AggregateFnV2):
485+
class Mean(AggregateFnV2[List[int], float]):
480486
"""Defines mean (average) aggregation.
481487
482488
Example:
@@ -521,7 +527,7 @@ def __init__(
521527
zero_factory=lambda: list([0, 0]), # noqa: C410
522528
)
523529

524-
def aggregate_block(self, block: Block) -> AggType:
530+
def aggregate_block(self, block: Block) -> List[int]:
525531
block_acc = BlockAccessor.for_block(block)
526532
count = block_acc.count(self._target_col_name, self._ignore_nulls)
527533

@@ -539,10 +545,10 @@ def aggregate_block(self, block: Block) -> AggType:
539545

540546
return [sum_, count]
541547

542-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
548+
def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
543549
return [current_accumulator[0] + new[0], current_accumulator[1] + new[1]]
544550

545-
def finalize(self, accumulator: AggType) -> Optional[U]:
551+
def finalize(self, accumulator: List[int]) -> Optional[float]:
546552
# The final accumulator for a group is [total_sum, total_count].
547553
if accumulator[1] == 0:
548554
# If total_count is 0 (e.g., group was empty or all nulls ignored),
@@ -553,7 +559,7 @@ def finalize(self, accumulator: AggType) -> Optional[U]:
553559

554560

555561
@PublicAPI
556-
class Std(AggregateFnV2):
562+
class Std(AggregateFnV2[List[float], float]):
557563
"""Defines standard deviation aggregation.
558564
559565
Uses Welford's online algorithm for numerical stability. This method computes
@@ -610,7 +616,7 @@ def __init__(
610616

611617
self._ddof = ddof
612618

613-
def aggregate_block(self, block: Block) -> AggType:
619+
def aggregate_block(self, block: Block) -> List[float]:
614620
block_acc = BlockAccessor.for_block(block)
615621
count = block_acc.count(self._target_col_name, ignore_nulls=self._ignore_nulls)
616622
if count == 0 or count is None:
@@ -627,7 +633,9 @@ def aggregate_block(self, block: Block) -> AggType:
627633
)
628634
return [M2, mean, count]
629635

630-
def combine(self, current_accumulator: List[float], new: List[float]) -> AggType:
636+
def combine(
637+
self, current_accumulator: List[float], new: List[float]
638+
) -> List[float]:
631639
# Merges two accumulators [M2, mean, count] using a parallel algorithm.
632640
# See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
633641
M2_a, mean_a, count_a = current_accumulator
@@ -643,7 +651,7 @@ def combine(self, current_accumulator: List[float], new: List[float]) -> AggType
643651
M2 = M2_a + M2_b + (delta**2) * count_a * count_b / count
644652
return [M2, mean, count]
645653

646-
def finalize(self, accumulator: List[float]) -> Optional[U]:
654+
def finalize(self, accumulator: List[float]) -> Optional[float]:
647655
# Compute the final standard deviation from the accumulated
648656
# sum of squared differences from current mean and the count.
649657
# Final accumulator: [M2, mean, count]
@@ -658,7 +666,7 @@ def finalize(self, accumulator: List[float]) -> Optional[U]:
658666

659667

660668
@PublicAPI
661-
class AbsMax(AggregateFnV2):
669+
class AbsMax(AggregateFnV2[Union[int, float], Union[int, float]]):
662670
"""Defines absolute max aggregation.
663671
664672
Example:
@@ -701,7 +709,7 @@ def __init__(
701709
zero_factory=lambda: 0,
702710
)
703711

704-
def aggregate_block(self, block: Block) -> AggType:
712+
def aggregate_block(self, block: Block) -> Union[int, float]:
705713
block_accessor = BlockAccessor.for_block(block)
706714

707715
max_ = block_accessor.max(self._target_col_name, self._ignore_nulls)
@@ -715,12 +723,14 @@ def aggregate_block(self, block: Block) -> AggType:
715723
abs(min_),
716724
)
717725

718-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
726+
def combine(
727+
self, current_accumulator: Union[int, float], new: Union[int, float]
728+
) -> Union[int, float]:
719729
return max(current_accumulator, new)
720730

721731

722732
@PublicAPI
723-
class Quantile(AggregateFnV2):
733+
class Quantile(AggregateFnV2[List[Any], List[Any]]):
724734
"""Defines Quantile aggregation.
725735
726736
Example:
@@ -790,7 +800,7 @@ def combine(self, current_accumulator: List[Any], new: List[Any]) -> List[Any]:
790800

791801
return ls
792802

793-
def aggregate_block(self, block: Block) -> AggType:
803+
def aggregate_block(self, block: Block) -> List[Any]:
794804
block_acc = BlockAccessor.for_block(block)
795805
ls = []
796806

@@ -799,7 +809,7 @@ def aggregate_block(self, block: Block) -> AggType:
799809

800810
return ls
801811

802-
def finalize(self, accumulator: List[Any]) -> Optional[U]:
812+
def finalize(self, accumulator: List[Any]) -> Optional[Any]:
803813
if self._ignore_nulls:
804814
accumulator = [v for v in accumulator if not is_null(v)]
805815
else:
@@ -831,7 +841,7 @@ def finalize(self, accumulator: List[Any]) -> Optional[U]:
831841

832842

833843
@PublicAPI
834-
class Unique(AggregateFnV2):
844+
class Unique(AggregateFnV2[set, set]):
835845
"""Defines unique aggregation.
836846
837847
Example:
@@ -870,10 +880,10 @@ def __init__(
870880
zero_factory=set,
871881
)
872882

873-
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
883+
def combine(self, current_accumulator: set, new: set) -> set:
874884
return self._to_set(current_accumulator) | self._to_set(new)
875885

876-
def aggregate_block(self, block: Block) -> AggType:
886+
def aggregate_block(self, block: Block) -> set:
877887
import pyarrow.compute as pac
878888

879889
col = BlockAccessor.for_block(block).to_arrow().column(self._target_col_name)
@@ -988,7 +998,6 @@ def _null_safe_combine(
988998
def _safe_combine(
989999
cur: Optional[AggType], new: Optional[AggType]
9901000
) -> Optional[AggType]:
991-
9921001
if is_null(cur):
9931002
return new
9941003
elif is_null(new):
@@ -1001,7 +1010,6 @@ def _safe_combine(
10011010
def _safe_combine(
10021011
cur: Optional[AggType], new: Optional[AggType]
10031012
) -> Optional[AggType]:
1004-
10051013
if is_null(new):
10061014
return new
10071015
elif is_null(cur):
@@ -1013,7 +1021,7 @@ def _safe_combine(
10131021

10141022

10151023
@PublicAPI(stability="alpha")
1016-
class MissingValuePercentage(AggregateFnV2):
1024+
class MissingValuePercentage(AggregateFnV2[List[int], float]):
10171025
"""Calculates the percentage of null values in a column.
10181026
10191027
This aggregation computes the percentage of null (missing) values in a dataset column.
@@ -1094,7 +1102,7 @@ def finalize(self, accumulator: List[int]) -> Optional[float]:
10941102

10951103

10961104
@PublicAPI(stability="alpha")
1097-
class ZeroPercentage(AggregateFnV2):
1105+
class ZeroPercentage(AggregateFnV2[List[int], float]):
10981106
"""Calculates the percentage of zero values in a numeric column.
10991107
11001108
This aggregation computes the percentage of zero values in a numeric dataset column.

0 commit comments

Comments
 (0)