diff --git a/test/test_distributed.py b/test/test_distributed.py index 176e0ee84..80fa9f1cc 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -149,6 +149,14 @@ def _test_fullsync(rank, world_size, backend, q): except Exception as e: assert isinstance(e, PrefetchTimeoutError) + # Test that reset/shutdown does not hang while paused + dp3 = dp.fullsync() + it = iter(dp3) + next(it) + dp3.pause() + it2 = iter(dp3) # Reset + next(it2) + _finalize_distributed_queue(rank, q) @world_size_parametrize diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py index ff534bb98..57ade7040 100644 --- a/torchdata/datapipes/iter/util/distributed.py +++ b/torchdata/datapipes/iter/util/distributed.py @@ -114,6 +114,7 @@ def return_next(self): return data def shutdown(self): + self._paused = False self._executor.shutdown(wait=True) def pause(self):