From 40e56e566635c81097ddad974f487e787d3586db Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Fri, 5 May 2023 12:46:05 -0400 Subject: [PATCH 1/3] [DataPipe] Fix FullSync shutdown hanging issue while paused [ghstack-poisoned] --- test/test_distributed.py | 8 ++++++++ torchdata/datapipes/iter/util/distributed.py | 1 + 2 files changed, 9 insertions(+) diff --git a/test/test_distributed.py b/test/test_distributed.py index 176e0ee84..80fa9f1cc 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -149,6 +149,14 @@ def _test_fullsync(rank, world_size, backend, q): except Exception as e: assert isinstance(e, PrefetchTimeoutError) + # Test that reset/shutdown does not hang while paused + dp3 = dp.fullsync() + it = iter(dp3) + next(it) + dp3.pause() + it2 = iter(dp3) # Reset + next(it2) + _finalize_distributed_queue(rank, q) @world_size_parametrize diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py index ff534bb98..57ade7040 100644 --- a/torchdata/datapipes/iter/util/distributed.py +++ b/torchdata/datapipes/iter/util/distributed.py @@ -114,6 +114,7 @@ def return_next(self): return data def shutdown(self): + self._paused = False self._executor.shutdown(wait=True) def pause(self): From 95c5c0b6eef0d5e6011934e0451b1b9a25b31194 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Fri, 5 May 2023 13:01:12 -0400 Subject: [PATCH 2/3] Update on "[DataPipe] Fix FullSync shutdown hanging issue while paused" Differential Revision: [D45610885](https://our.internmc.facebook.com/intern/diff/D45610885) [ghstack-poisoned] --- torchdata/datapipes/iter/util/distributed.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py index 57ade7040..bb1ff1aa7 100644 --- a/torchdata/datapipes/iter/util/distributed.py +++ b/torchdata/datapipes/iter/util/distributed.py @@ -69,8 +69,9 @@ def __init__( self._executor = 
ThreadPoolExecutor(max_workers=1) self._futures: Deque[Future] = deque() self._lock = threading.RLock() - self._end_flag = False - self._paused = False + self._end_flag: bool = False + self._paused: bool = False + self._is_shutdown: bool = False self._idx = 0 for _ in range(prefetch_size): with self._lock: @@ -93,7 +94,7 @@ def _done_callback_fn(self, index: int, f: Future): if f.exception(): with self._lock: self._end_flag = True - if self.callback_fn is not None: + if self.callback_fn is not None and not self._is_shutdown: self._executor.submit(self.callback_fn, Expected(index, f.exception())) def return_next(self): @@ -104,7 +105,7 @@ def return_next(self): except TimeoutError: raise PrefetchTimeoutError(self.timeout) with self._lock: - if not self._end_flag: + if not self._end_flag and not self._is_shutdown: next_future = self._executor.submit(self.fetch_next) next_future.add_done_callback(partial(self._done_callback_fn, self._idx)) self._futures.append(next_future) @@ -115,6 +116,7 @@ def return_next(self): def shutdown(self): self._paused = False + self._is_shutdown = True self._executor.shutdown(wait=True) def pause(self): From 84e0bdd7a17e3b5c990c38623790eb9621f91511 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Fri, 5 May 2023 13:19:45 -0400 Subject: [PATCH 3/3] Update on "[DataPipe] Fix FullSync shutdown hanging issue while paused" Before this PR, the executor within FullSync fails to shutdown if it were currently paused. This PR allows shutdown without submitting additional jobs. 
Differential Revision: [D45610885](https://our.internmc.facebook.com/intern/diff/D45610885) [ghstack-poisoned] --- torchdata/datapipes/iter/util/distributed.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py index bb1ff1aa7..14a3ab2a8 100644 --- a/torchdata/datapipes/iter/util/distributed.py +++ b/torchdata/datapipes/iter/util/distributed.py @@ -69,9 +69,11 @@ def __init__( self._executor = ThreadPoolExecutor(max_workers=1) self._futures: Deque[Future] = deque() self._lock = threading.RLock() + # `_end_flag` indicates the end of epoch or an exception has been raised, + # with the exception being handled by `callback_fn` self._end_flag: bool = False self._paused: bool = False - self._is_shutdown: bool = False + self._is_shutdown: bool = False # indicates if `_executor` has been shutdown by `shutdown` method self._idx = 0 for _ in range(prefetch_size): with self._lock: @@ -95,6 +97,7 @@ def _done_callback_fn(self, index: int, f: Future): with self._lock: self._end_flag = True if self.callback_fn is not None and not self._is_shutdown: + # Doesn't invoke `callback_fn` if `shutdown` is called self._executor.submit(self.callback_fn, Expected(index, f.exception())) def return_next(self):