diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index d172ee33e428..205b5b3e12f4 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -183,6 +183,8 @@ def add_split_op(dag): self._next_epoch = gen_epochs() self._output_iterator = None + # Used for debugging https://github.com/ray-project/ray/issues/45225 + self._debug_info = {} def stats(self) -> DatasetStats: """Returns stats from the base dataset.""" @@ -249,9 +251,11 @@ def get( def _barrier(self, split_idx: int) -> int: """Arrive and block until the start of the given epoch.""" + self._debug_info[split_idx] = {} # Decrement and await all clients to arrive here. with self._lock: starting_epoch = self._cur_epoch + self._debug_info[split_idx]["starting_epoch"] = starting_epoch self._unfinished_clients_in_epoch -= 1 start_time = time.time() @@ -271,11 +275,31 @@ def _barrier(self, split_idx: int) -> int: time.sleep(0.1) # Advance to the next epoch. + self._debug_info[split_idx]["entering_lock"] = ( + self._cur_epoch, + self._output_iterator is None, + time.time(), + ) with self._lock: + self._debug_info[split_idx]["entered_lock"] = ( + self._cur_epoch, + self._output_iterator is None, + time.time(), + ) if self._cur_epoch == starting_epoch: self._cur_epoch += 1 self._unfinished_clients_in_epoch = self._n self._output_iterator = next(self._next_epoch) + self._debug_info[split_idx]["set_iter"] = ( + self._cur_epoch, + self._output_iterator is None, + time.time(), + ) + self._debug_info[split_idx]["leaving_lock"] = ( + self._cur_epoch, + self._output_iterator is None, + time.time(), + ) - assert self._output_iterator is not None + assert self._output_iterator is not None, self._debug_info return starting_epoch + 1