Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataPipe] Fix FullSync shutdown hanging issue while paused #1153

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def _test_fullsync(rank, world_size, backend, q):
except Exception as e:
assert isinstance(e, PrefetchTimeoutError)

# Test that reset/shutdown does not hang while paused
dp3 = dp.fullsync()
it = iter(dp3)
next(it)
dp3.pause()
it2 = iter(dp3) # Reset
next(it2)
NivekT marked this conversation as resolved.
Show resolved Hide resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this test failing without the patch?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for DataLoader2 with fullsync?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, it fails without the patch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't really add a DataLoader2 test because DistributedRS currently doesn't support pause....

I will have to add that separately. Let me know if I should land this as it is or add that on top of this.

_finalize_distributed_queue(rank, q)

@world_size_parametrize
Expand Down
27 changes: 20 additions & 7 deletions torchdata/datapipes/iter/util/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ def __init__(
self._executor = ThreadPoolExecutor(max_workers=1)
self._futures: Deque[Future] = deque()
self._lock = threading.RLock()
self._end_flag = False
self._paused = False
# `_end_flag` indicates the end of epoch or an exception has been raised,
# with the exception being handled by `callback_fn`
self._end_flag: bool = False
self._paused: bool = False
self._is_shutdown: bool = False # indicates if `_executor` has been shutdown by `shutdown` method
self._idx = 0
for _ in range(prefetch_size):
with self._lock:
Expand All @@ -85,15 +88,14 @@ def __init__(
def fetch_next(self):
while self._paused:
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

res = next(self.datapipe_iterator)
return res
return next(self.datapipe_iterator)

def _done_callback_fn(self, index: int, f: Future):
if f.exception():
with self._lock:
self._end_flag = True
if self.callback_fn is not None:
if self.callback_fn is not None and not self._is_shutdown:
NivekT marked this conversation as resolved.
Show resolved Hide resolved
# Doesn't invoke `callback_fn` if `shutdown` is caleld
self._executor.submit(self.callback_fn, Expected(index, f.exception()))

def return_next(self):
Expand All @@ -104,7 +106,7 @@ def return_next(self):
except TimeoutError:
raise PrefetchTimeoutError(self.timeout)
with self._lock:
if not self._end_flag:
if not self._end_flag and not self._is_shutdown:
next_future = self._executor.submit(self.fetch_next)
next_future.add_done_callback(partial(self._done_callback_fn, self._idx))
self._futures.append(next_future)
Expand All @@ -114,6 +116,8 @@ def return_next(self):
return data

def shutdown(self):
self._paused = False
self._is_shutdown = True
self._executor.shutdown(wait=True)

def pause(self):
Expand Down Expand Up @@ -259,3 +263,12 @@ def pause(self):
def resume(self):
if self._executor is not None:
self._executor.resume()

@final
def shutdown(self):
if self._executor is not None:
self._executor.shutdown()
self._executor = None

def __del__(self):
self.shutdown()