Removing delegation for 'pause', 'limit', and 'resume' (#1011)
Summary: Pull Request resolved: #1011

Test Plan: Imported from OSS

Reviewed By: mingyuzh, ejguan

Differential Revision: D43251818

Pulled By: NivekT

fbshipit-source-id: 5c34e1a71438308366b473c5d4d075a8158088f1
NivekT authored and ejguan committed Feb 28, 2023
1 parent 6ace7a4 commit b57545f
Showing 2 changed files with 18 additions and 12 deletions.
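For orientation, here is a rough usage sketch of what this change means for callers: the iterator-level `pause`, `limit`, and `resume` controls now route through `DataLoader2`'s own `_pause`/`_limit`/`_resume`, which dispatch to the attached `ReadingService`, rather than being delegated to the underlying DataPipe iterator. Everything in the snippet other than those three calls (the pipeline, the reading-service class, and its arguments) is an illustrative assumption about the torchdata API at this revision, not part of this diff.

# Hypothetical pipeline and reading service; only the pause/limit/resume
# calls relate to this commit.
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(100)).shuffle().batch(4)
rs = PrototypeMultiProcessingReadingService(num_workers=2)
dl = DataLoader2(datapipe=dp, reading_service=rs)

it = iter(dl)
first_batch = next(it)

it.pause()    # DataLoader2Iterator -> DataLoader2._pause() -> ReadingService._pause()
it.limit(5)   # now forwards to DataLoader2._limit(5) instead of the DataPipe iterator's limit()
it.resume()   # DataLoader2Iterator.resume() -> DataLoader2._resume() -> ReadingService._resume()

dl.shutdown()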
23 changes: 11 additions & 12 deletions torchdata/dataloader2/dataloader2.py
@@ -100,8 +100,6 @@ def resume(self) -> None:
         Restarts the threads within ``DataLoader2`` and allows it to yield additional batches.
         """
         self.dataloader._resume()
-        if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "resume"):
-            self.dataloader._datapipe_iter.resume()  # type: ignore[attr-defined]
 
     def limit(self, num_batches: Optional[int]) -> None:
         """
@@ -120,8 +118,7 @@ def limit(self, num_batches: Optional[int]) -> None:
         """
         self.limit_counter = 0
         self.limit_threshold = num_batches
-        if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "limit"):
-            self.dataloader._datapipe_iter.limit(num_batches)  # type: ignore[attr-defined]
+        self.dataloader._limit(num_batches)
 
     def __getattr__(self, name):
         """
@@ -339,11 +336,8 @@ def _pause(self):
         if hasattr(self.reading_service, "_pause"):
             self._is_paused = True
             self.reading_service._pause()
-        # TODO: the condition should be `else` once `self._datapipe_iter.pause/limit()` is no longer used
-        elif self._datapipe_iter is None or not (
-            hasattr(self._datapipe_iter, "limit") or hasattr(self._datapipe_iter, "pause")
-        ):
-            warnings.warn("ReadingService doesn't support pause.")
+        else:
+            warnings.warn("ReadingService doesn't support `pause`.")
 
     def _resume(self):
         if hasattr(self.reading_service, "_resume"):
@@ -352,6 +346,11 @@ def _resume(self):
             else:
                 self.reading_service._resume()
                 self._is_paused = False
-        # TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used
-        elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"):
-            warnings.warn("ReadingService doesn't support resume.")
+        else:
+            warnings.warn("ReadingService doesn't support `resume`.")
+
+    def _limit(self, num_batches: Optional[int]) -> None:
+        if hasattr(self.reading_service, "_limit"):
+            self.reading_service._limit(num_batches)
+        else:
+            warnings.warn("ReadingService doesn't support `limit`.")
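As the hunks above show, DataLoader2's `_pause`, `_resume`, and `_limit` now probe the attached `ReadingService` with `hasattr` and fall back to a warning, so a custom service opts into these controls simply by defining the corresponding private hooks. A minimal sketch under that assumption follows; the class name, its in-process behavior, and the details beyond the three hooks are illustrative and not part of this commit.

from typing import Optional

from torchdata.dataloader2 import ReadingServiceInterface


class InProcessPausableReadingService(ReadingServiceInterface):
    """Hypothetical sketch: defining _pause/_resume/_limit is enough for
    DataLoader2's hasattr() dispatch to find them and skip the warnings."""

    def initialize(self, datapipe):
        # Serve the graph in-process; a real service would set up workers,
        # prefetching, etc., and pause/resume them in the hooks below.
        self._paused = False
        return datapipe

    def _pause(self) -> None:
        self._paused = True

    def _resume(self) -> None:
        self._paused = False

    def _limit(self, num_batches: Optional[int]) -> None:
        # Nothing to enforce here: per the diff, DataLoader2Iterator tracks
        # limit_counter/limit_threshold itself.
        pass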
7 changes: 7 additions & 0 deletions torchdata/dataloader2/reading_service.py
@@ -406,6 +406,13 @@ def _resume(self):
         if self.main_prefetch_cnt > 0 and self.num_workers > 0:
             self._main_prefetch_datapipe.resume()  # type: ignore[union-attr]
 
+    def _limit(self, num_batches: Optional[int]) -> None:
+        """
+        For this ReadingService, `DataLoader2Iterator` and `DataLoader2` should sufficiently handle
+        the limit operation, such that nothing needs to be done here.
+        """
+        pass
+
 
 class DistributedReadingService(ReadingServiceInterface):
     r"""
