diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py
index f807ae2078dc..9dc052d12aaa 100644
--- a/python/ray/data/_internal/block_batching/iter_batches.py
+++ b/python/ray/data/_internal/block_batching/iter_batches.py
@@ -135,6 +135,7 @@ def __init__(
             if actor_prefetcher_enabled
             else WaitBlockPrefetcher()
         )
+        self._yielded_first_batch = False
 
     def _prefetch_blocks(
         self, ref_bundles: Iterator[RefBundle]
@@ -235,15 +236,29 @@ def __iter__(self) -> Iterator[DataBatch]:
         return self._iter_batches()
 
     def before_epoch_start(self):
-        pass
+        self._yielded_first_batch = False
 
     def after_epoch_end(self):
         StatsManager.clear_iteration_metrics(self._dataset_tag)
 
     @contextmanager
     def get_next_batch_context(self):
-        with self._stats.iter_total_blocked_s.timer() if self._stats else nullcontext():
-            yield
+        try:
+            if self._stats:
+                # Always track total blocked time
+                total_timer = self._stats.iter_total_blocked_s.timer()
+                # Also track the time until the first batch is ready
+                first_batch_ready_timer = (
+                    self._stats.iter_time_to_first_batch_s.timer()
+                    if not self._yielded_first_batch
+                    else nullcontext()
+                )
+                with total_timer, first_batch_ready_timer:
+                    yield
+            else:
+                yield
+        finally:
+            self._yielded_first_batch = True
 
     @contextmanager
     def yield_batch_context(self, batch: Batch):
diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index d00b45c89b8a..7f4222f68426 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -280,6 +280,12 @@ def __init__(self, max_stats=1000):
             description="Seconds user thread is blocked by iter_batches()",
             tag_keys=iter_tag_keys,
         )
+        self.time_to_first_batch_s = Gauge(
+            "data_iter_time_to_first_batch_seconds",
+            description="Total time spent waiting for the first batch after starting iteration. "
+            "This includes the dataset pipeline warmup time. This metric is accumulated across different epochs.",
+            tag_keys=iter_tag_keys,
+        )
         self.iter_user_s = Gauge(
             "data_iter_user_seconds",
             description="Seconds spent in user code",
@@ -469,6 +475,7 @@ def update_iteration_metrics(
     ):
         tags = self._create_tags(dataset_tag)
         self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
+        self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags)
         self.iter_user_s.set(stats.iter_user_s.get(), tags)
         self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)
 
@@ -948,6 +955,7 @@ def __init__(
         self.iter_format_batch_s: Timer = Timer()
         self.iter_collate_batch_s: Timer = Timer()
         self.iter_finalize_batch_s: Timer = Timer()
+        self.iter_time_to_first_batch_s: Timer = Timer()
         self.iter_total_blocked_s: Timer = Timer()
         self.iter_user_s: Timer = Timer()
         self.iter_initialize_s: Timer = Timer()
@@ -1003,6 +1011,7 @@ def to_summary(self) -> "DatasetStatsSummary":
             self.iter_format_batch_s,
             self.iter_collate_batch_s,
             self.iter_finalize_batch_s,
+            self.iter_time_to_first_batch_s,
             self.iter_total_blocked_s,
             self.iter_user_s,
             self.iter_initialize_s,
@@ -1642,6 +1651,8 @@ class IterStatsSummary:
     collate_time: Timer
     # Time spent in finalize_fn, in seconds
     finalize_batch_time: Timer
+    # Time user thread is blocked waiting for first batch
+    time_to_first_batch: Timer
     # Total time user thread is blocked by iter_batches
     block_time: Timer
     # Time spent in user code, in seconds
@@ -1665,6 +1676,7 @@ def to_string(self) -> str:
         out = ""
         if (
             self.block_time.get()
+            or self.time_to_first_batch.get()
             or self.total_time.get()
             or self.get_time.get()
             or self.next_time.get()
@@ -1685,6 +1697,11 @@ def to_string(self) -> str:
                     " * Total time user thread is blocked by Ray Data iter_batches: "
                     "{}\n".format(fmt(self.block_time.get()))
                 )
+            if self.time_to_first_batch.get():
+                out += (
+                    " * Total time spent waiting for the first batch after starting iteration: "
+                    "{}\n".format(fmt(self.time_to_first_batch.get()))
+                )
             if self.user_time.get():
                 out += " * Total execution time for user thread: {}\n".format(
                     fmt(self.user_time.get())
diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py
index 0a3a32e9d63e..cbe5b5dfa8ba 100644
--- a/python/ray/data/tests/test_stats.py
+++ b/python/ray/data/tests/test_stats.py
@@ -395,6 +395,7 @@ def test_streaming_split_stats(ray_start_regular_shared, restore_data_context):
 * Total time overall: T
  * Total time in Ray Data iterator initialization code: T
  * Total time user thread is blocked by Ray Data iter_batches: T
+ * Total time spent waiting for the first batch after starting iteration: T
  * Total execution time for user thread: T
 * Batch iteration time breakdown (summed across prefetch threads):
  * In ray.get(): T min, T max, T avg, T total
@@ -577,6 +578,7 @@ def test_dataset_stats_basic(
         f"* Total time overall: T\n"
         f" * Total time in Ray Data iterator initialization code: T\n"
         f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
+        f" * Total time spent waiting for the first batch after starting iteration: T\n"
         f" * Total execution time for user thread: T\n"
         f"* Batch iteration time breakdown (summed across prefetch threads):\n"
         f" * In ray.get(): T min, T max, T avg, T total\n"
@@ -618,6 +620,7 @@ def test_block_location_nums(ray_start_regular_shared, restore_data_context):
         f"* Total time overall: T\n"
         f" * Total time in Ray Data iterator initialization code: T\n"
         f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
+        f" * Total time spent waiting for the first batch after starting iteration: T\n"
         f" * Total execution time for user thread: T\n"
         f"* Batch iteration time breakdown (summed across prefetch threads):\n"
         f" * In ray.get(): T min, T max, T avg, T total\n"
@@ -1363,6 +1366,7 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_data_context):
 * Total time overall: T
  * Total time in Ray Data iterator initialization code: T
  * Total time user thread is blocked by Ray Data iter_batches: T
+ * Total time spent waiting for the first batch after starting iteration: T
  * Total execution time for user thread: T
 * Batch iteration time breakdown (summed across prefetch threads):
  * In ray.get(): T min, T max, T avg, T total