diff --git a/test/test_distributed.py b/test/test_distributed.py index 176e0ee84..80fa9f1cc 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -149,6 +149,14 @@ def _test_fullsync(rank, world_size, backend, q): except Exception as e: assert isinstance(e, PrefetchTimeoutError) + # Test that reset/shutdown does not hang while paused + dp3 = dp.fullsync() + it = iter(dp3) + next(it) + dp3.pause() + it2 = iter(dp3) # Reset + next(it2) + _finalize_distributed_queue(rank, q) @world_size_parametrize diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py index ff534bb98..bb1ff1aa7 100644 --- a/torchdata/datapipes/iter/util/distributed.py +++ b/torchdata/datapipes/iter/util/distributed.py @@ -69,8 +69,9 @@ def __init__( self._executor = ThreadPoolExecutor(max_workers=1) self._futures: Deque[Future] = deque() self._lock = threading.RLock() - self._end_flag = False - self._paused = False + self._end_flag: bool = False + self._paused: bool = False + self._is_shutdown: bool = False self._idx = 0 for _ in range(prefetch_size): with self._lock: @@ -93,7 +94,7 @@ def _done_callback_fn(self, index: int, f: Future): if f.exception(): with self._lock: self._end_flag = True - if self.callback_fn is not None: + if self.callback_fn is not None and not self._is_shutdown: self._executor.submit(self.callback_fn, Expected(index, f.exception())) def return_next(self): @@ -104,7 +105,7 @@ def return_next(self): except TimeoutError: raise PrefetchTimeoutError(self.timeout) with self._lock: - if not self._end_flag: + if not self._end_flag and not self._is_shutdown: next_future = self._executor.submit(self.fetch_next) next_future.add_done_callback(partial(self._done_callback_fn, self._idx)) self._futures.append(next_future) @@ -114,6 +115,8 @@ def return_next(self): return data def shutdown(self): + self._paused = False + self._is_shutdown = True self._executor.shutdown(wait=True) def pause(self):