Skip to content

Commit

Permalink
[data] cleanup: use SortKey instead of mixed typing in aggregation (r…
Browse files Browse the repository at this point in the history
…ay-project#48697)

## Why are these changes needed?

This makes SortAggregate more consistent by unifying the API on the
SortKey object, similar to how SortTaskSpec is implemented.


## Related issue number

This is related to ray-project#42776 and
ray-project#42142


Signed-off-by: Richard Liaw <rliaw@berkeley.edu>
  • Loading branch information
richardliaw authored and JP-sDEV committed Nov 14, 2024
1 parent a48e818 commit 2bd8a4d
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 93 deletions.
67 changes: 29 additions & 38 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
Expand Down Expand Up @@ -502,13 +503,13 @@ def sort_and_partition(

return find_partitions(table, boundaries, sort_key)

def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Block:
def combine(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
"""Combine rows with the same key into an accumulator.
This assumes the block is already sorted by key in ascending order.
Args:
key: A column name or list of column names.
sort_key: A column name or list of column names.
If this is ``None``, place all rows in a single group.
aggs: The aggregations to do.
Expand All @@ -519,18 +520,13 @@ def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Blo
aggregation.
If key is None then the k column is omitted.
"""
if key is not None and not isinstance(key, (str, list)):
raise ValueError(
"key must be a string, list of strings or None when aggregating "
"on Arrow blocks, but "
f"got: {type(key)}."
)
keys: List[str] = sort_key.get_columns()

def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
"""Creates an iterator over zero-copy group views."""
if key is None:
if not keys:
# Global aggregation consists of a single "group", so we short-circuit.
yield None, self.to_block()
yield tuple(), self.to_block()
return

start = end = 0
Expand All @@ -540,36 +536,33 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
try:
if next_row is None:
next_row = next(iter)
next_key = next_row[key]
while next_row[key] == next_key:
next_keys = next_row[keys]
while next_row[keys] == next_keys:
end += 1
try:
next_row = next(iter)
except StopIteration:
next_row = None
break
yield next_key, self.slice(start, end)
yield next_keys, self.slice(start, end)
start = end
except StopIteration:
break

builder = ArrowBlockBuilder()
for group_key, group_view in iter_groups():
for group_keys, group_view in iter_groups():
# Aggregate.
accumulators = [agg.init(group_key) for agg in aggs]
init_vals = group_keys
if len(group_keys) == 1:
init_vals = group_keys[0]

accumulators = [agg.init(init_vals) for agg in aggs]
for i in range(len(aggs)):
accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)

# Build the row.
row = {}
if key is not None:
if isinstance(key, list):
keys = key
group_keys = group_key
else:
keys = [key]
group_keys = [group_key]

if keys:
for k, gk in zip(keys, group_keys):
row[k] = gk

Expand Down Expand Up @@ -608,7 +601,7 @@ def merge_sorted_blocks(
@staticmethod
def aggregate_combined_blocks(
blocks: List[Block],
key: Union[str, List[str]],
sort_key: "SortKey",
aggs: Tuple["AggregateFn"],
finalize: bool,
) -> Tuple[Block, BlockMetadata]:
Expand All @@ -619,7 +612,7 @@ def aggregate_combined_blocks(
Args:
blocks: A list of partially combined and sorted blocks.
key: The column name of key or None for global aggregation.
sort_key: The column name of key or None for global aggregation.
aggs: The aggregations to do.
finalize: Whether to finalize the aggregation. This is used as an
optimization for cases where we repeatedly combine partially
Expand All @@ -633,13 +626,13 @@ def aggregate_combined_blocks(
"""

stats = BlockExecStats.builder()
keys = sort_key.get_columns()

keys = key if isinstance(key, list) else [key]
key_fn = (
(lambda r: tuple(r[r._row.schema.names[: len(keys)]]))
if key is not None
else (lambda r: (0,))
)
def key_fn(r):
if keys:
return tuple(r[keys])
else:
return (0,)

# Handle blocks of different types.
blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")
Expand All @@ -658,9 +651,7 @@ def aggregate_combined_blocks(
if next_row is None:
next_row = next(iter)
next_keys = key_fn(next_row)
next_key_names = (
next_row._row.schema.names[: len(keys)] if key is not None else None
)
next_key_columns = keys

def gen():
nonlocal iter
Expand Down Expand Up @@ -699,9 +690,9 @@ def gen():
)
# Build the row.
row = {}
if key is not None:
for next_key, next_key_name in zip(next_keys, next_key_names):
row[next_key_name] = next_key
if keys:
for col_name, next_key in zip(next_key_columns, next_keys):
row[col_name] = next_key

for agg, agg_name, accumulator in zip(
aggs, resolved_agg_names, accumulators
Expand Down
70 changes: 32 additions & 38 deletions python/ray/data/_internal/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
Expand Down Expand Up @@ -415,14 +416,14 @@ def sort_and_partition(
return find_partitions(table, boundaries, sort_key)

def combine(
self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]
self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]
) -> "pandas.DataFrame":
"""Combine rows with the same key into an accumulator.
This assumes the block is already sorted by key in ascending order.
Args:
key: A column name or list of column names.
sort_key: A SortKey object which holds column names/keys.
If this is ``None``, place all rows in a single group.
aggs: The aggregations to do.
Expand All @@ -433,18 +434,14 @@ def combine(
aggregation.
If key is None then the k column is omitted.
"""
if key is not None and not isinstance(key, (str, list)):
raise ValueError(
"key must be a string, list of strings or None when aggregating "
"on Pandas blocks, but "
f"got: {type(key)}."
)
keys: List[str] = sort_key.get_columns()
pd = lazy_import_pandas()

def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
"""Creates an iterator over zero-copy group views."""
if key is None:
if not keys:
# Global aggregation consists of a single "group", so we short-circuit.
yield None, self.to_block()
yield tuple(), self.to_block()
return

start = end = 0
Expand All @@ -454,36 +451,34 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
try:
if next_row is None:
next_row = next(iter)
next_key = next_row[key]
while np.all(next_row[key] == next_key):
next_keys = next_row[keys]
while np.all(next_row[keys] == next_keys):
end += 1
try:
next_row = next(iter)
except StopIteration:
next_row = None
break
yield next_key, self.slice(start, end, copy=False)
if isinstance(next_keys, pd.Series):
next_keys = next_keys.values
yield next_keys, self.slice(start, end, copy=False)
start = end
except StopIteration:
break

builder = PandasBlockBuilder()
for group_key, group_view in iter_groups():
for group_keys, group_view in iter_groups():
# Aggregate.
accumulators = [agg.init(group_key) for agg in aggs]
init_vals = group_keys
if len(group_keys) == 1:
init_vals = group_keys[0]
accumulators = [agg.init(init_vals) for agg in aggs]
for i in range(len(aggs)):
accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)

# Build the row.
row = {}
if key is not None:
if isinstance(key, list):
keys = key
group_keys = group_key
else:
keys = [key]
group_keys = [group_key]

if keys:
for k, gk in zip(keys, group_keys):
row[k] = gk

Expand Down Expand Up @@ -520,7 +515,7 @@ def merge_sorted_blocks(
@staticmethod
def aggregate_combined_blocks(
blocks: List["pandas.DataFrame"],
key: Union[str, List[str]],
sort_key: "SortKey",
aggs: Tuple["AggregateFn"],
finalize: bool,
) -> Tuple["pandas.DataFrame", BlockMetadata]:
Expand All @@ -531,7 +526,7 @@ def aggregate_combined_blocks(
Args:
blocks: A list of partially combined and sorted blocks.
key: The column name of key or None for global aggregation.
sort_key: The column name of key or None for global aggregation.
aggs: The aggregations to do.
finalize: Whether to finalize the aggregation. This is used as an
optimization for cases where we repeatedly combine partially
Expand All @@ -545,12 +540,13 @@ def aggregate_combined_blocks(
"""

stats = BlockExecStats.builder()
keys = key if isinstance(key, list) else [key]
key_fn = (
(lambda r: tuple(r[r._row.columns[: len(keys)]]))
if key is not None
else (lambda r: (0,))
)
keys = sort_key.get_columns()

def key_fn(r):
if keys:
return tuple(r[keys])
else:
return (0,)

# Handle blocks of different types.
blocks = TableBlockAccessor.normalize_block_types(blocks, "pandas")
Expand All @@ -569,9 +565,7 @@ def aggregate_combined_blocks(
if next_row is None:
next_row = next(iter)
next_keys = key_fn(next_row)
next_key_names = (
next_row._row.columns[: len(keys)] if key is not None else None
)
next_key_columns = keys

def gen():
nonlocal iter
Expand Down Expand Up @@ -610,9 +604,9 @@ def gen():
)
# Build the row.
row = {}
if key is not None:
for next_key, next_key_name in zip(next_keys, next_key_names):
row[next_key_name] = next_key
if keys:
for col_name, next_key in zip(next_key_columns, next_keys):
row[col_name] = next_key

for agg, agg_name, accumulator in zip(
aggs, resolved_agg_names, accumulators
Expand Down
10 changes: 6 additions & 4 deletions python/ray/data/_internal/planner/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

from ray.data._internal.execution.interfaces import (
AllToAllTransformFn,
Expand All @@ -22,7 +22,7 @@


def generate_aggregate_fn(
key: Optional[str],
key: Optional[Union[str, List[str]]],
aggs: List[AggregateFn],
batch_format: str,
_debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
Expand Down Expand Up @@ -50,6 +50,8 @@ def fn(

num_mappers = len(blocks)

sort_key = SortKey(key)

if key is None:
num_outputs = 1
boundaries = []
Expand All @@ -61,12 +63,12 @@ def fn(
]
# Sample boundaries for aggregate key.
boundaries = SortTaskSpec.sample_boundaries(
blocks, SortKey(key), num_outputs, sample_bar
blocks, sort_key, num_outputs, sample_bar
)

agg_spec = SortAggregateTaskSpec(
boundaries=boundaries,
key=key,
key=sort_key,
aggs=aggs,
batch_format=batch_format,
)
Expand Down
Loading

0 comments on commit 2bd8a4d

Please sign in to comment.