50 changes: 40 additions & 10 deletions python/ray/data/_internal/execution/operators/hash_shuffle.py
@@ -10,11 +10,11 @@
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
Callable,
DefaultDict,
Deque,
Dict,
Generator,
List,
Optional,
Set,
@@ -49,6 +49,7 @@
)
from ray.data._internal.execution.operators.sub_progress import SubProgressBarMixin
from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.output_buffer import BlockOutputBuffer, OutputBlockSizeOption
from ray.data._internal.stats import OpRuntimeMetrics
from ray.data._internal.table_block import TableBlockAccessor
from ray.data._internal.util import GiB, MiB
@@ -836,9 +837,6 @@ def _on_bundle_ready(partition_id: int, bundle: RefBundle):
self.reduce_metrics.on_task_output_generated(
task_index=partition_id, output=bundle
)
self.reduce_metrics.on_task_finished(
task_index=partition_id, exception=None
)
_, num_outputs, num_rows = estimate_total_num_of_blocks(
partition_id + 1,
self.upstream_op_num_outputs(),
@@ -857,6 +855,11 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
if partition_id in self._finalizing_tasks:
self._finalizing_tasks.pop(partition_id)

# Update finalize metrics on task completion
Contributor Author: drive-by

self.reduce_metrics.on_task_finished(
task_index=partition_id, exception=exc
)

if exc:
logger.error(
f"Aggregation of the {partition_id} partition "
@@ -1335,7 +1338,12 @@ def start(self):

aggregator = HashShuffleAggregator.options(
**self._aggregator_ray_remote_args
).remote(aggregator_id, target_partition_ids, self._aggregation_factory_ref)
).remote(
aggregator_id,
target_partition_ids,
self._aggregation_factory_ref,
self._data_context,
)

self._aggregators.append(aggregator)

@@ -1547,30 +1555,52 @@ def __init__(
aggregator_id: int,
target_partition_ids: List[int],
agg_factory: StatefulShuffleAggregationFactory,
data_context: DataContext,
):
self._lock = threading.Lock()
self._agg: StatefulShuffleAggregation = agg_factory(
aggregator_id, target_partition_ids
)
self._data_context = data_context

def submit(self, input_seq_id: int, partition_id: int, partition_shard: Block):
with self._lock:
self._agg.accept(input_seq_id, partition_id, partition_shard)

def finalize(
self, partition_id: int
) -> AsyncGenerator[Union[Block, "BlockMetadataWithSchema"], None]:
) -> Generator[Union[Block, "BlockMetadataWithSchema"], None, None]:

with self._lock:
exec_stats_builder = BlockExecStats.builder()
# Finalize given partition id
block = self._agg.finalize(partition_id)
exec_stats = exec_stats_builder.build()
# Clear any remaining state (to release resources)
self._agg.clear(partition_id)

# TODO break down blocks to target size
yield block
yield BlockMetadataWithSchema.from_block(block, stats=exec_stats)
target_max_block_size = self._data_context.target_max_block_size
# None means the user wants to preserve the block distribution,
# so we do not break the block down further.
if target_max_block_size is not None:
# Create a block output buffer per partition finalize task because
# retrying finalize tasks causes the stateful output buffer to become
# fragmented (i.e., duplicated blocks get added and finalize is called twice).
output_buffer = BlockOutputBuffer(
output_block_size_option=OutputBlockSizeOption(
target_max_block_size=target_max_block_size
)
)

output_buffer.add_block(block)
output_buffer.finalize()
while output_buffer.has_next():
block = output_buffer.next()
yield block
yield BlockMetadataWithSchema.from_block(block, stats=exec_stats)
else:
yield block
yield BlockMetadataWithSchema.from_block(block, stats=exec_stats)
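
For context, a minimal standalone sketch of the splitting behavior this branch relies on, using only the BlockOutputBuffer / OutputBlockSizeOption calls that appear in the patch; the pyarrow table and the 1 MiB target are illustrative assumptions, not values used by the operator:

import pyarrow as pa

from ray.data._internal.output_buffer import BlockOutputBuffer, OutputBlockSizeOption

# Illustrative input: roughly 8 MiB of int64 data in a single pyarrow table.
block = pa.table({"value": list(range(1_000_000))})

# Buffer configured to cut output blocks at roughly 1 MiB (assumed target).
output_buffer = BlockOutputBuffer(
    output_block_size_option=OutputBlockSizeOption(target_max_block_size=1024 * 1024)
)

output_buffer.add_block(block)
output_buffer.finalize()

while output_buffer.has_next():
    chunk = output_buffer.next()
    # Each chunk is a Block (a pyarrow table here) at or near the target size.
    print(chunk.num_rows, chunk.nbytes)

Because the buffer is created inside finalize() for each partition, a retried finalize task starts from a fresh buffer instead of appending to partially drained state.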


def _get_total_cluster_resources() -> ExecutionResources:
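The splitting branch is driven by DataContext.target_max_block_size, which start() now forwards into each aggregator actor. A hedged sketch of the user-facing knob, assuming the usual Ray Data default of roughly 128 MiB:

from ray.data.context import DataContext

ctx = DataContext.get_current()

# Keep the default (typically 128 MiB) to have finalize split large partitions,
# or set None to preserve one output block per partition, as handled above.
ctx.target_max_block_size = None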