Skip to content

Commit

Permalink
Fix FullSync shutdown hanging issue while paused (#1153)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #1153

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D45610885

Pulled By: NivekT

fbshipit-source-id: c7e5d6675664ec44ea5d73f7464a1f43ad1d9282
  • Loading branch information
NivekT authored and facebook-github-bot committed May 15, 2023
1 parent 65e2ede commit 0b39117
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
8 changes: 8 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def _test_fullsync(rank, world_size, backend, q):
except Exception as e:
assert isinstance(e, PrefetchTimeoutError)

# Test that reset/shutdown does not hang while paused
dp3 = dp.fullsync()
it = iter(dp3)
next(it)
dp3.pause()
it2 = iter(dp3) # Reset
next(it2)

_finalize_distributed_queue(rank, q)

@world_size_parametrize
Expand Down
27 changes: 20 additions & 7 deletions torchdata/datapipes/iter/util/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ def __init__(
self._executor = ThreadPoolExecutor(max_workers=1)
self._futures: Deque[Future] = deque()
self._lock = threading.RLock()
self._end_flag = False
self._paused = False
# `_end_flag` indicates that the epoch has ended or that an exception has been raised;
# in the latter case the exception is handled by `callback_fn` (when one is set)
self._end_flag: bool = False
self._paused: bool = False
self._is_shutdown: bool = False # indicates if `_executor` has been shutdown by `shutdown` method
self._idx = 0
for _ in range(prefetch_size):
with self._lock:
Expand All @@ -85,15 +88,14 @@ def __init__(
def fetch_next(self):
while self._paused:
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

res = next(self.datapipe_iterator)
return res
return next(self.datapipe_iterator)

def _done_callback_fn(self, index: int, f: Future):
if f.exception():
with self._lock:
self._end_flag = True
if self.callback_fn is not None:
if self.callback_fn is not None and not self._is_shutdown:
# Doesn't invoke `callback_fn` if `shutdown` is called
self._executor.submit(self.callback_fn, Expected(index, f.exception()))

def return_next(self):
Expand All @@ -104,7 +106,7 @@ def return_next(self):
except TimeoutError:
raise PrefetchTimeoutError(self.timeout)
with self._lock:
if not self._end_flag:
if not self._end_flag and not self._is_shutdown:
next_future = self._executor.submit(self.fetch_next)
next_future.add_done_callback(partial(self._done_callback_fn, self._idx))
self._futures.append(next_future)
Expand All @@ -114,6 +116,8 @@ def return_next(self):
return data

def shutdown(self):
# Stop the prefetch executor and block until its worker thread has finished.
# Clear the pause flag first: `fetch_next` spins in `while self._paused`, so a paused
# worker would never return and the `wait=True` below would hang.
self._paused = False
# Mark as shut down so `_done_callback_fn` and `return_next` stop submitting new work
# to the executor.
self._is_shutdown = True
self._executor.shutdown(wait=True)

def pause(self):
Expand Down Expand Up @@ -259,3 +263,12 @@ def pause(self):
def resume(self):
# Forward the resume request to the executor; does nothing when `_executor` is None.
if self._executor is not None:
self._executor.resume()

@final
def shutdown(self):
# Shut down the executor if one exists, then drop the reference.
# Once `_executor` is None, a repeated call (e.g. from `__del__`) skips the
# executor shutdown.
if self._executor is not None:
self._executor.shutdown()
self._executor = None

def __del__(self):
# Finalizer: delegate to `shutdown` so the executor is released when this object is
# destroyed.
self.shutdown()

0 comments on commit 0b39117

Please sign in to comment.