diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py index 9e98bbf6a..1154a10ee 100644 --- a/torchdata/dataloader2/dataloader2.py +++ b/torchdata/dataloader2/dataloader2.py @@ -100,8 +100,6 @@ def resume(self) -> None: Restarts the threads within ``DataLoader2`` and allows it to yield additional batches. """ self.dataloader._resume() - if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "resume"): - self.dataloader._datapipe_iter.resume() # type: ignore[attr-defined] def limit(self, num_batches: Optional[int]) -> None: """ @@ -120,8 +118,7 @@ def limit(self, num_batches: Optional[int]) -> None: """ self.limit_counter = 0 self.limit_threshold = num_batches - if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "limit"): - self.dataloader._datapipe_iter.limit(num_batches) # type: ignore[attr-defined] + self.dataloader._limit(num_batches) def __getattr__(self, name): """ @@ -339,11 +336,8 @@ def _pause(self): if hasattr(self.reading_service, "_pause"): self._is_paused = True self.reading_service._pause() - # TODO: the condition should be `else` once `self._datapipe_iter.pause/limit()` is no longer used - elif self._datapipe_iter is None or not ( - hasattr(self._datapipe_iter, "limit") or hasattr(self._datapipe_iter, "pause") - ): - warnings.warn("ReadingService doesn't support pause.") + else: + warnings.warn("ReadingService doesn't support `pause`.") def _resume(self): if hasattr(self.reading_service, "_resume"): @@ -352,6 +346,11 @@ def _resume(self): else: self.reading_service._resume() self._is_paused = False - # TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used - elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"): - warnings.warn("ReadingService doesn't support resume.") + else: + warnings.warn("ReadingService doesn't support `resume`.") + + def _limit(self, num_batches: Optional[int]) -> None: + if hasattr(self.reading_service, "_limit"): + self.reading_service._limit(num_batches) + else: + warnings.warn("ReadingService doesn't support `limit`.") diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index 934b468c5..9b4143c08 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -406,6 +406,13 @@ def _resume(self): if self.main_prefetch_cnt > 0 and self.num_workers > 0: self._main_prefetch_datapipe.resume() # type: ignore[union-attr] + def _limit(self, num_batches: Optional[int]) -> None: + """ + For this ReadingService, `DataLoader2Iterator` and `DataLoader2` should sufficiently handle + the limit operation, such that nothing needs to be done here. + """ + pass + class DistributedReadingService(ReadingServiceInterface): r"""