diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py
index 4b24dc687..5b1359dbe 100644
--- a/torchdata/dataloader2/communication/iter.py
+++ b/torchdata/dataloader2/communication/iter.py
@@ -34,7 +34,17 @@ class NotAvailable(Exception):
 class InvalidStateResetRequired(Exception):
     """
     Returned by DataPipe when it is expecting to get reset request,
-    for example RouterDataPipe expecting all workers to request reset'
+    for example RouterDataPipe expecting all workers to request reset.
+    """
+
+    pass
+
+
+class TerminateRequired(Exception):
+    """
+    Returned by DataPipe when it is expecting to get a terminate request,
+    for example when it has received a terminate request from another source
+    and is in the process of stopping.
     """
 
     pass
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index e144e9027..01187611f 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -7,9 +7,8 @@
 import functools
 import multiprocessing as mp
 
-import time
 from abc import ABC, abstractmethod
-from typing import Any, Callable, List, Optional
+from typing import Callable, List, Optional
 
 import torch
 from torch.utils.data import DataLoader
@@ -102,24 +101,42 @@ def _collate_no_op(batch):
 class _IterateQueueDataPipes:
     def __init__(self, datapipes):
         self.datapipes = datapipes
+        for dp in self.datapipes:
+            if not isinstance(dp, communication.iter.QueueWrapper):
+                raise Exception("Source datapipes should be an instance of iter.QueueWrapper")
 
     def __iter__(self):
-        # TODO(612): This is slow as it does not sends data requests ahead.
-        exclude_datapipes: List[Any] = []
-        while len(exclude_datapipes) < len(self.datapipes):
-            for dp in self.datapipes:
-                if dp not in exclude_datapipes:
-                    forever = True
-                    while forever:
-                        try:
-                            value = dp.nonblocking_next()
-                            yield value
-                            forever = False
-                        except StopIteration:
-                            exclude_datapipes.append(dp)
-                            forever = False
-                        except communication.iter.NotAvailable:
-                            time.sleep(0.001)
+        self.reset()
+        total_pipes = len(self.datapipes)
+        disabled_pipe = [False] * len(self.datapipes)
+        cnt_disabled_pipes = 0
+
+        for idx in range(total_pipes):
+            self.datapipes[idx].protocol.request_next()
+
+        while cnt_disabled_pipes < total_pipes:
+            for idx in range(total_pipes):
+                if not disabled_pipe[idx]:
+                    response = self.datapipes[idx].protocol.get_response_next(block=True)
+                    if isinstance(response, communication.messages.StopIterationResponse):
+                        disabled_pipe[idx] = True
+                        cnt_disabled_pipes += 1
+                        break
+                    if isinstance(response, communication.messages.InvalidStateResponse):
+                        raise communication.iter.InvalidStateResetRequired
+                    if isinstance(response, communication.messages.TerminateResponse):
+                        raise communication.iter.TerminateRequired
+                    self.datapipes[idx].protocol.request_next()
+                    yield response.value
+
+    def reset(self):
+        # Collect all existing request results to clear the queues
+        for dp in self.datapipes:
+            if dp.protocol.waiting_for_response():
+                dp.protocol.get_response_next(block=True)
+        # NonBlocking DataPipes do not reset automatically, have to do it manually
+        for dp in self.datapipes:
+            dp.reset_iterator()
 
 
 class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
@@ -166,8 +183,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
         return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False)  # type: ignore[return-value]
 
     def initialize_iteration(self) -> None:
-        for dp in self.datapipes:
-            dp.reset_iterator()
+        pass
 
     def __del__(self):
         self.finalize()
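
Note (not part of the patch): below is a minimal, self-contained sketch of the round-robin request/response pattern that the new _IterateQueueDataPipes.__iter__ uses, where one request is kept in flight per pipe instead of polling with sleeps. _FakePipe and round_robin are hypothetical stand-ins for illustration; the real QueueWrapper and protocol classes live in torchdata.dataloader2.communication.

# Illustration only (not part of the patch): a toy version of the round-robin
# request/response loop used by the new _IterateQueueDataPipes.__iter__ above.
# _FakePipe is a hypothetical stand-in for a QueueWrapper-style pipe whose
# protocol answers each request with either a value or an "exhausted" marker.
from collections import deque


class _FakePipe:
    def __init__(self, items):
        self._items = deque(items)
        self._pending = False

    def request_next(self):
        # Mark that a request is outstanding; a real pipe would enqueue a
        # "get next" request for its worker process here.
        self._pending = True

    def get_response_next(self):
        assert self._pending, "no outstanding request"
        self._pending = False
        # The StopIteration class itself doubles as the "exhausted" marker.
        return self._items.popleft() if self._items else StopIteration


def round_robin(pipes):
    # Prime every pipe with one request so workers can compute ahead, then
    # collect responses in round-robin order, re-arming each pipe before
    # yielding its value.
    disabled = [False] * len(pipes)
    remaining = len(pipes)
    for p in pipes:
        p.request_next()
    while remaining:
        for idx, p in enumerate(pipes):
            if disabled[idx]:
                continue
            response = p.get_response_next()
            if response is StopIteration:
                disabled[idx] = True
                remaining -= 1
                continue
            p.request_next()
            yield response


if __name__ == "__main__":
    print(list(round_robin([_FakePipe([1, 2]), _FakePipe(["a", "b", "c"])])))
    # -> [1, 'a', 2, 'b', 'c']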