diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py index 598025325b13..3ea83b9f04cc 100644 --- a/python/ray/data/_internal/arrow_block.py +++ b/python/ray/data/_internal/arrow_block.py @@ -10,6 +10,7 @@ Iterator, List, Optional, + Sequence, Tuple, TypeVar, Union, @@ -502,13 +503,13 @@ def sort_and_partition( return find_partitions(table, boundaries, sort_key) - def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Block: + def combine(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block: """Combine rows with the same key into an accumulator. This assumes the block is already sorted by key in ascending order. Args: - key: A column name or list of column names. + sort_key: A column name or list of column names. If this is ``None``, place all rows in a single group. aggs: The aggregations to do. @@ -519,18 +520,13 @@ def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Blo aggregation. If key is None then the k column is omitted. """ - if key is not None and not isinstance(key, (str, list)): - raise ValueError( - "key must be a string, list of strings or None when aggregating " - "on Arrow blocks, but " - f"got: {type(key)}." - ) + keys: List[str] = sort_key.get_columns() - def iter_groups() -> Iterator[Tuple[KeyType, Block]]: + def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]: """Creates an iterator over zero-copy group views.""" - if key is None: + if not keys: # Global aggregation consists of a single "group", so we short-circuit. - yield None, self.to_block() + yield tuple(), self.to_block() return start = end = 0 @@ -540,36 +536,33 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]: try: if next_row is None: next_row = next(iter) - next_key = next_row[key] - while next_row[key] == next_key: + next_keys = next_row[keys] + while next_row[keys] == next_keys: end += 1 try: next_row = next(iter) except StopIteration: next_row = None break - yield next_key, self.slice(start, end) + yield next_keys, self.slice(start, end) start = end except StopIteration: break builder = ArrowBlockBuilder() - for group_key, group_view in iter_groups(): + for group_keys, group_view in iter_groups(): # Aggregate. - accumulators = [agg.init(group_key) for agg in aggs] + init_vals = group_keys + if len(group_keys) == 1: + init_vals = group_keys[0] + + accumulators = [agg.init(init_vals) for agg in aggs] for i in range(len(aggs)): accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view) # Build the row. row = {} - if key is not None: - if isinstance(key, list): - keys = key - group_keys = group_key - else: - keys = [key] - group_keys = [group_key] - + if keys: for k, gk in zip(keys, group_keys): row[k] = gk @@ -608,7 +601,7 @@ def merge_sorted_blocks( @staticmethod def aggregate_combined_blocks( blocks: List[Block], - key: Union[str, List[str]], + sort_key: "SortKey", aggs: Tuple["AggregateFn"], finalize: bool, ) -> Tuple[Block, BlockMetadata]: @@ -619,7 +612,7 @@ def aggregate_combined_blocks( Args: blocks: A list of partially combined and sorted blocks. - key: The column name of key or None for global aggregation. + sort_key: The column name of key or None for global aggregation. aggs: The aggregations to do. finalize: Whether to finalize the aggregation. This is used as an optimization for cases where we repeatedly combine partially @@ -633,13 +626,13 @@ def aggregate_combined_blocks( """ stats = BlockExecStats.builder() + keys = sort_key.get_columns() - keys = key if isinstance(key, list) else [key] - key_fn = ( - (lambda r: tuple(r[r._row.schema.names[: len(keys)]])) - if key is not None - else (lambda r: (0,)) - ) + def key_fn(r): + if keys: + return tuple(r[keys]) + else: + return (0,) # Handle blocks of different types. blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow") @@ -658,9 +651,7 @@ def aggregate_combined_blocks( if next_row is None: next_row = next(iter) next_keys = key_fn(next_row) - next_key_names = ( - next_row._row.schema.names[: len(keys)] if key is not None else None - ) + next_key_columns = keys def gen(): nonlocal iter @@ -699,9 +690,9 @@ def gen(): ) # Build the row. row = {} - if key is not None: - for next_key, next_key_name in zip(next_keys, next_key_names): - row[next_key_name] = next_key + if keys: + for col_name, next_key in zip(next_key_columns, next_keys): + row[col_name] = next_key for agg, agg_name, accumulator in zip( aggs, resolved_agg_names, accumulators diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py index 214aa19b78c1..04ff4a35a7e0 100644 --- a/python/ray/data/_internal/pandas_block.py +++ b/python/ray/data/_internal/pandas_block.py @@ -8,6 +8,7 @@ Iterator, List, Optional, + Sequence, Tuple, TypeVar, Union, @@ -415,14 +416,14 @@ def sort_and_partition( return find_partitions(table, boundaries, sort_key) def combine( - self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"] + self, sort_key: "SortKey", aggs: Tuple["AggregateFn"] ) -> "pandas.DataFrame": """Combine rows with the same key into an accumulator. This assumes the block is already sorted by key in ascending order. Args: - key: A column name or list of column names. + sort_key: A SortKey object which holds column names/keys. If this is ``None``, place all rows in a single group. aggs: The aggregations to do. @@ -433,18 +434,14 @@ def combine( aggregation. If key is None then the k column is omitted. """ - if key is not None and not isinstance(key, (str, list)): - raise ValueError( - "key must be a string, list of strings or None when aggregating " - "on Pandas blocks, but " - f"got: {type(key)}." - ) + keys: List[str] = sort_key.get_columns() + pd = lazy_import_pandas() - def iter_groups() -> Iterator[Tuple[KeyType, Block]]: + def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]: """Creates an iterator over zero-copy group views.""" - if key is None: + if not keys: # Global aggregation consists of a single "group", so we short-circuit. - yield None, self.to_block() + yield tuple(), self.to_block() return start = end = 0 @@ -454,36 +451,34 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]: try: if next_row is None: next_row = next(iter) - next_key = next_row[key] - while np.all(next_row[key] == next_key): + next_keys = next_row[keys] + while np.all(next_row[keys] == next_keys): end += 1 try: next_row = next(iter) except StopIteration: next_row = None break - yield next_key, self.slice(start, end, copy=False) + if isinstance(next_keys, pd.Series): + next_keys = next_keys.values + yield next_keys, self.slice(start, end, copy=False) start = end except StopIteration: break builder = PandasBlockBuilder() - for group_key, group_view in iter_groups(): + for group_keys, group_view in iter_groups(): # Aggregate. - accumulators = [agg.init(group_key) for agg in aggs] + init_vals = group_keys + if len(group_keys) == 1: + init_vals = group_keys[0] + accumulators = [agg.init(init_vals) for agg in aggs] for i in range(len(aggs)): accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view) # Build the row. row = {} - if key is not None: - if isinstance(key, list): - keys = key - group_keys = group_key - else: - keys = [key] - group_keys = [group_key] - + if keys: for k, gk in zip(keys, group_keys): row[k] = gk @@ -520,7 +515,7 @@ def merge_sorted_blocks( @staticmethod def aggregate_combined_blocks( blocks: List["pandas.DataFrame"], - key: Union[str, List[str]], + sort_key: "SortKey", aggs: Tuple["AggregateFn"], finalize: bool, ) -> Tuple["pandas.DataFrame", BlockMetadata]: @@ -531,7 +526,7 @@ def aggregate_combined_blocks( Args: blocks: A list of partially combined and sorted blocks. - key: The column name of key or None for global aggregation. + sort_key: The column name of key or None for global aggregation. aggs: The aggregations to do. finalize: Whether to finalize the aggregation. This is used as an optimization for cases where we repeatedly combine partially @@ -545,12 +540,13 @@ def aggregate_combined_blocks( """ stats = BlockExecStats.builder() - keys = key if isinstance(key, list) else [key] - key_fn = ( - (lambda r: tuple(r[r._row.columns[: len(keys)]])) - if key is not None - else (lambda r: (0,)) - ) + keys = sort_key.get_columns() + + def key_fn(r): + if keys: + return tuple(r[keys]) + else: + return (0,) # Handle blocks of different types. blocks = TableBlockAccessor.normalize_block_types(blocks, "pandas") @@ -569,9 +565,7 @@ def aggregate_combined_blocks( if next_row is None: next_row = next(iter) next_keys = key_fn(next_row) - next_key_names = ( - next_row._row.columns[: len(keys)] if key is not None else None - ) + next_key_columns = keys def gen(): nonlocal iter @@ -610,9 +604,9 @@ def gen(): ) # Build the row. row = {} - if key is not None: - for next_key, next_key_name in zip(next_keys, next_key_names): - row[next_key_name] = next_key + if keys: + for col_name, next_key in zip(next_key_columns, next_keys): + row[col_name] = next_key for agg, agg_name, accumulator in zip( aggs, resolved_agg_names, accumulators diff --git a/python/ray/data/_internal/planner/aggregate.py b/python/ray/data/_internal/planner/aggregate.py index 8f177add41d9..2d32719dc849 100644 --- a/python/ray/data/_internal/planner/aggregate.py +++ b/python/ray/data/_internal/planner/aggregate.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from ray.data._internal.execution.interfaces import ( AllToAllTransformFn, @@ -22,7 +22,7 @@ def generate_aggregate_fn( - key: Optional[str], + key: Optional[Union[str, List[str]]], aggs: List[AggregateFn], batch_format: str, _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None, @@ -50,6 +50,8 @@ def fn( num_mappers = len(blocks) + sort_key = SortKey(key) + if key is None: num_outputs = 1 boundaries = [] @@ -61,12 +63,12 @@ def fn( ] # Sample boundaries for aggregate key. boundaries = SortTaskSpec.sample_boundaries( - blocks, SortKey(key), num_outputs, sample_bar + blocks, sort_key, num_outputs, sample_bar ) agg_spec = SortAggregateTaskSpec( boundaries=boundaries, - key=key, + key=sort_key, aggs=aggs, batch_format=batch_format, ) diff --git a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py index 7b0aa0dc7ad8..a07a35302ba0 100644 --- a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union from ray.data._internal.aggregate import Count, _AggregateOnKeyBase from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec @@ -27,7 +27,7 @@ class SortAggregateTaskSpec(ExchangeTaskSpec): def __init__( self, boundaries: List[KeyType], - key: Optional[str], + key: SortKey, aggs: List[AggregateFn], batch_format: str, ): @@ -42,26 +42,26 @@ def map( block: Block, output_num_blocks: int, boundaries: List[KeyType], - key: Union[str, List[str], None], + sort_key: SortKey, aggs: List[AggregateFn], ) -> List[Union[BlockMetadata, Block]]: stats = BlockExecStats.builder() - block = SortAggregateTaskSpec._prune_unused_columns(block, key, aggs) - if key is None: - partitions = [block] - else: + block = SortAggregateTaskSpec._prune_unused_columns(block, sort_key, aggs) + if sort_key.get_columns(): partitions = BlockAccessor.for_block(block).sort_and_partition( boundaries, - SortKey(key), + sort_key, ) - parts = [BlockAccessor.for_block(p).combine(key, aggs) for p in partitions] + else: + partitions = [block] + parts = [BlockAccessor.for_block(p).combine(sort_key, aggs) for p in partitions] meta = BlockAccessor.for_block(block).get_metadata(exec_stats=stats.build()) return parts + [meta] @staticmethod def reduce( - key: Optional[str], + key: SortKey, aggs: List[AggregateFn], batch_format: str, *mapper_outputs: List[Block], @@ -77,12 +77,13 @@ def reduce( @staticmethod def _prune_unused_columns( block: Block, - key: Union[str, List[str]], + sort_key: SortKey, aggs: Tuple[AggregateFn], ) -> Block: """Prune unused columns from block before aggregate.""" prune_columns = True columns = set() + key = sort_key.get_columns() if isinstance(key, str): columns.add(key) diff --git a/python/ray/data/block.py b/python/ray/data/block.py index 1f8c156c8578..15cf6b68b20c 100644 --- a/python/ray/data/block.py +++ b/python/ray/data/block.py @@ -437,7 +437,7 @@ def sort_and_partition( """Return a list of sorted partitions of this block.""" raise NotImplementedError - def combine(self, key: Optional[str], agg: "AggregateFn") -> Block: + def combine(self, key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block: """Combine rows with the same key into an accumulator.""" raise NotImplementedError @@ -450,7 +450,7 @@ def merge_sorted_blocks( @staticmethod def aggregate_combined_blocks( - blocks: List[Block], key: Optional[str], agg: "AggregateFn" + blocks: List[Block], sort_key: "SortKey", aggs: Tuple["AggregateFn"] ) -> Tuple[Block, BlockMetadata]: """Aggregate partially combined and sorted blocks.""" raise NotImplementedError diff --git a/python/ray/data/tests/test_all_to_all.py b/python/ray/data/tests/test_all_to_all.py index 2b0a72175cf9..a7b6ec823a9c 100644 --- a/python/ray/data/tests/test_all_to_all.py +++ b/python/ray/data/tests/test_all_to_all.py @@ -134,6 +134,12 @@ def test_groupby_arrow(ray_start_regular_shared, use_push_based_shuffle): assert agg_ds.count() == 0 +def test_groupby_none(ray_start_regular_shared): + ds = ray.data.range(10) + assert ds.groupby(None).min().take_all() == [{"min(id)": 0}] + assert ds.groupby(None).max().take_all() == [{"max(id)": 9}] + + def test_groupby_errors(ray_start_regular_shared): ds = ray.data.range(100) ds.groupby(None).count().show() # OK @@ -208,6 +214,44 @@ def create_large_data(group): ds.take(1) +@pytest.mark.parametrize("keys", ["A", ["A", "B"]]) +def test_agg_inputs(ray_start_regular_shared, keys): + xs = list(range(100)) + ds = ray.data.from_items([{"A": (x % 3), "B": x, "C": (x % 2)} for x in xs]) + + def check_init(k): + if len(keys) == 2: + assert isinstance(k, tuple), k + assert len(k) == 2 + elif len(keys) == 1: + assert isinstance(k, int) + return 1 + + def check_finalize(v): + assert v == 1 + + def check_accumulate_merge(a, r): + assert a == 1 + if isinstance(r, int): + return 1 + elif len(r) == 3: + assert all(x in r for x in ["A", "B", "C"]) + else: + assert False, r + return 1 + + output = ds.groupby(keys).aggregate( + AggregateFn( + init=check_init, + accumulate_row=check_accumulate_merge, + merge=check_accumulate_merge, + finalize=check_finalize, + name="foo", + ) + ) + output.take_all() + + def test_agg_errors(ray_start_regular_shared): from ray.data._internal.aggregate import Max