@@ -10,11 +10,11 @@
 from dataclasses import dataclass
 from typing import (
     Any,
-    AsyncGenerator,
     Callable,
     DefaultDict,
     Deque,
     Dict,
+    Generator,
     List,
     Optional,
     Set,
@@ -49,6 +49,7 @@
 )
 from ray.data._internal.execution.operators.sub_progress import SubProgressBarMixin
 from ray.data._internal.logical.interfaces import LogicalOperator
+from ray.data._internal.output_buffer import BlockOutputBuffer, OutputBlockSizeOption
 from ray.data._internal.stats import OpRuntimeMetrics
 from ray.data._internal.table_block import TableBlockAccessor
 from ray.data._internal.util import GiB, MiB
@@ -836,9 +837,6 @@ def _on_bundle_ready(partition_id: int, bundle: RefBundle):
             self.reduce_metrics.on_task_output_generated(
                 task_index=partition_id, output=bundle
             )
-            self.reduce_metrics.on_task_finished(
-                task_index=partition_id, exception=None
-            )
             _, num_outputs, num_rows = estimate_total_num_of_blocks(
                 partition_id + 1,
                 self.upstream_op_num_outputs(),
@@ -857,6 +855,11 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
             if partition_id in self._finalizing_tasks:
                 self._finalizing_tasks.pop(partition_id)

+            # Update finalize metrics on task completion
+            self.reduce_metrics.on_task_finished(
+                task_index=partition_id, exception=exc
+            )
+
             if exc:
                 logger.error(
                     f"Aggregation of the {partition_id} partition "
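Note: the two hunks above move the completion metric from the output callback into the done callback. The output callback fires only when a bundle is produced, while the done callback fires on success and failure and carries the exception, so failed finalizations are now counted. A minimal self-contained sketch of that callback split, using illustrative names (run_task, on_output, on_done are not Ray Data's API):

from typing import Callable, Dict, Optional

def run_task(
    task_index: int,
    fn: Callable[[], object],
    on_output: Callable[[int, object], None],
    on_done: Callable[[int, Optional[Exception]], None],
) -> None:
    # The output callback fires only when the task produces a result; the
    # done callback fires either way and carries the exception (or None),
    # so completion metrics recorded there also count failed tasks.
    try:
        result = fn()
    except Exception as exc:
        on_done(task_index, exc)
    else:
        on_output(task_index, result)
        on_done(task_index, None)

finished: Dict[int, Optional[Exception]] = {}
run_task(0, lambda: 42, lambda i, o: None, lambda i, e: finished.update({i: e}))
run_task(1, lambda: 1 / 0, lambda i, o: None, lambda i, e: finished.update({i: e}))
assert finished[0] is None and isinstance(finished[1], ZeroDivisionError)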
@@ -1335,7 +1338,12 @@ def start(self):

             aggregator = HashShuffleAggregator.options(
                 **self._aggregator_ray_remote_args
-            ).remote(aggregator_id, target_partition_ids, self._aggregation_factory_ref)
+            ).remote(
+                aggregator_id,
+                target_partition_ids,
+                self._aggregation_factory_ref,
+                self._data_context,
+            )

             self._aggregators.append(aggregator)

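Passing the DataContext into the aggregator's constructor lets each actor read driver-side settings such as target_max_block_size on the worker, instead of relying on process-global state. A minimal sketch of the pattern under assumed names (the Aggregator class and dict-based context below are illustrative, not Ray Data's classes):

import ray

@ray.remote
class Aggregator:
    def __init__(self, aggregator_id, target_partition_ids, data_context):
        # The context travels with the actor, so settings resolved on the
        # driver are visible inside every remote method call.
        self._id = aggregator_id
        self._partitions = target_partition_ids
        self._data_context = data_context

    def target_max_block_size(self):
        return self._data_context["target_max_block_size"]

ray.init(ignore_reinit_error=True)
agg = Aggregator.options(num_cpus=0.5).remote(
    0, [0, 1], {"target_max_block_size": 128 * 1024 * 1024}
)
assert ray.get(agg.target_max_block_size.remote()) == 128 * 1024 * 1024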
@@ -1547,30 +1555,52 @@ def __init__(
         aggregator_id: int,
         target_partition_ids: List[int],
         agg_factory: StatefulShuffleAggregationFactory,
+        data_context: DataContext,
     ):
         self._lock = threading.Lock()
         self._agg: StatefulShuffleAggregation = agg_factory(
             aggregator_id, target_partition_ids
         )
+        self._data_context = data_context

     def submit(self, input_seq_id: int, partition_id: int, partition_shard: Block):
         with self._lock:
             self._agg.accept(input_seq_id, partition_id, partition_shard)

     def finalize(
         self, partition_id: int
-    ) -> AsyncGenerator[Union[Block, "BlockMetadataWithSchema"], None]:
+    ) -> Generator[Union[Block, "BlockMetadataWithSchema"], None, None]:
+
         with self._lock:
-            # Finalize given partition id
             exec_stats_builder = BlockExecStats.builder()
+            # Finalize given partition id
             block = self._agg.finalize(partition_id)
             exec_stats = exec_stats_builder.build()
             # Clear any remaining state (to release resources)
             self._agg.clear(partition_id)

-            # TODO break down blocks to target size
-            yield block
-            yield BlockMetadataWithSchema.from_block(block, stats=exec_stats)
+            target_max_block_size = self._data_context.target_max_block_size
+            # None means the user wants to preserve the block distribution,
+            # so we do not break the block down further.
+            if target_max_block_size is not None:
+                # Create a fresh block output buffer per finalize task:
+                # retried finalize tasks would fragment a shared, stateful
+                # output buffer (i.e., add duplicate blocks, finalize twice).
+                output_buffer = BlockOutputBuffer(
+                    output_block_size_option=OutputBlockSizeOption(
+                        target_max_block_size=target_max_block_size
+                    )
+                )
+
+                output_buffer.add_block(block)
+                output_buffer.finalize()
+                while output_buffer.has_next():
+                    block = output_buffer.next()
+                    yield block
+                    yield BlockMetadataWithSchema.from_block(block, stats=exec_stats)
+            else:
+                yield block
+                yield BlockMetadataWithSchema.from_block(block, stats=exec_stats)


 def _get_total_cluster_resources() -> ExecutionResources:
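The chunking added to finalize() follows the buffer protocol visible in the hunk: add_block() stages data, finalize() seals the buffer, and has_next()/next() drain blocks capped at the target size. A toy stand-in that caps chunks by row count rather than by bytes, to show the control flow (ToyOutputBuffer is illustrative, not Ray Data's BlockOutputBuffer, which targets a byte size):

from collections import deque
from typing import Deque, List

class ToyOutputBuffer:
    """Splits accepted blocks into chunks of at most target_max_rows rows."""

    def __init__(self, target_max_rows: int):
        self._target_max_rows = target_max_rows
        self._chunks: Deque[List[int]] = deque()
        self._finalized = False

    def add_block(self, block: List[int]) -> None:
        # Stage the block, pre-split to the target chunk size.
        for i in range(0, len(block), self._target_max_rows):
            self._chunks.append(block[i : i + self._target_max_rows])

    def finalize(self) -> None:
        # Seal the buffer; after this, has_next()/next() drain the chunks.
        self._finalized = True

    def has_next(self) -> bool:
        return self._finalized and bool(self._chunks)

    def next(self) -> List[int]:
        return self._chunks.popleft()

buf = ToyOutputBuffer(target_max_rows=4)
buf.add_block(list(range(10)))
buf.finalize()
chunks = []
while buf.has_next():
    chunks.append(buf.next())
assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]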