Skip to content
Merged
21 changes: 18 additions & 3 deletions python/ray/data/_internal/block_batching/iter_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(
if actor_prefetcher_enabled
else WaitBlockPrefetcher()
)
self._yielded_first_batch = False

def _prefetch_blocks(
self, ref_bundles: Iterator[RefBundle]
Expand Down Expand Up @@ -235,15 +236,29 @@ def __iter__(self) -> Iterator[DataBatch]:
return self._iter_batches()

def before_epoch_start(self):
pass
self._yielded_first_batch = False

def after_epoch_end(self):
StatsManager.clear_iteration_metrics(self._dataset_tag)

@contextmanager
def get_next_batch_context(self):
with self._stats.iter_total_blocked_s.timer() if self._stats else nullcontext():
yield
try:
if self._stats:
# Always track total blocked time
total_timer = self._stats.iter_total_blocked_s.timer()
# Also track the time until the first batch is ready
first_batch_ready_timer = (
self._stats.iter_time_to_first_batch_s.timer()
if not self._yielded_first_batch
else nullcontext()
)
with total_timer, first_batch_ready_timer:
yield
else:
yield
finally:
self._yielded_first_batch = True

@contextmanager
def yield_batch_context(self, batch: Batch):
Expand Down
17 changes: 17 additions & 0 deletions python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ def __init__(self, max_stats=1000):
description="Seconds user thread is blocked by iter_batches()",
tag_keys=iter_tag_keys,
)
self.time_to_first_batch_s = Gauge(
"data_iter_time_to_first_batch_seconds",
description="Total time spent waiting for the first batch after starting iteration. "
"This includes the dataset pipeline warmup time. This metric is accumulated across different epochs.",
tag_keys=iter_tag_keys,
)
self.iter_user_s = Gauge(
"data_iter_user_seconds",
description="Seconds spent in user code",
Expand Down Expand Up @@ -469,6 +475,7 @@ def update_iteration_metrics(
):
tags = self._create_tags(dataset_tag)
self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags)
self.iter_user_s.set(stats.iter_user_s.get(), tags)
self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)

Expand Down Expand Up @@ -948,6 +955,7 @@ def __init__(
self.iter_format_batch_s: Timer = Timer()
self.iter_collate_batch_s: Timer = Timer()
self.iter_finalize_batch_s: Timer = Timer()
self.iter_time_to_first_batch_s: Timer = Timer()
self.iter_total_blocked_s: Timer = Timer()
self.iter_user_s: Timer = Timer()
self.iter_initialize_s: Timer = Timer()
Expand Down Expand Up @@ -1003,6 +1011,7 @@ def to_summary(self) -> "DatasetStatsSummary":
self.iter_format_batch_s,
self.iter_collate_batch_s,
self.iter_finalize_batch_s,
self.iter_time_to_first_batch_s,
self.iter_total_blocked_s,
self.iter_user_s,
self.iter_initialize_s,
Expand Down Expand Up @@ -1642,6 +1651,8 @@ class IterStatsSummary:
collate_time: Timer
# Time spent in finalize_fn, in seconds
finalize_batch_time: Timer
# Time user thread is blocked waiting for first batch
time_to_first_batch: Timer
# Total time user thread is blocked by iter_batches
block_time: Timer
# Time spent in user code, in seconds
Expand All @@ -1665,6 +1676,7 @@ def to_string(self) -> str:
out = ""
if (
self.block_time.get()
or self.time_to_first_batch.get()
or self.total_time.get()
or self.get_time.get()
or self.next_time.get()
Expand All @@ -1685,6 +1697,11 @@ def to_string(self) -> str:
" * Total time user thread is blocked by Ray Data iter_batches: "
"{}\n".format(fmt(self.block_time.get()))
)
if self.time_to_first_batch.get():
out += (
" * Total time spent waiting for the first batch after starting iteration: "
"{}\n".format(fmt(self.time_to_first_batch.get()))
)
if self.user_time.get():
out += " * Total execution time for user thread: {}\n".format(
fmt(self.user_time.get())
Expand Down
4 changes: 4 additions & 0 deletions python/ray/data/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def test_streaming_split_stats(ray_start_regular_shared, restore_data_context):
* Total time overall: T
* Total time in Ray Data iterator initialization code: T
* Total time user thread is blocked by Ray Data iter_batches: T
* Total time spent waiting for the first batch after starting iteration: T
* Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
* In ray.get(): T min, T max, T avg, T total
Expand Down Expand Up @@ -577,6 +578,7 @@ def test_dataset_stats_basic(
f"* Total time overall: T\n"
f" * Total time in Ray Data iterator initialization code: T\n"
f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
f" * Total time spent waiting for the first batch after starting iteration: T\n"
f" * Total execution time for user thread: T\n"
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
f" * In ray.get(): T min, T max, T avg, T total\n"
Expand Down Expand Up @@ -618,6 +620,7 @@ def test_block_location_nums(ray_start_regular_shared, restore_data_context):
f"* Total time overall: T\n"
f" * Total time in Ray Data iterator initialization code: T\n"
f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
f" * Total time spent waiting for the first batch after starting iteration: T\n"
f" * Total execution time for user thread: T\n"
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
f" * In ray.get(): T min, T max, T avg, T total\n"
Expand Down Expand Up @@ -1363,6 +1366,7 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_data_context):
* Total time overall: T
* Total time in Ray Data iterator initialization code: T
* Total time user thread is blocked by Ray Data iter_batches: T
* Total time spent waiting for the first batch after starting iteration: T
* Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
* In ray.get(): T min, T max, T avg, T total
Expand Down