diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 4b24dc687..5b1359dbe 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -34,7 +34,17 @@ class NotAvailable(Exception): class InvalidStateResetRequired(Exception): """ Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' + for example RouterDataPipe expecting all workers to request reset. + """ + + pass + + +class TerminateRequired(Exception): + """ + Returned by DataPipe when it is expecting to get a terminate request, + for example when it got a terminate request from another source and is in the + process of stopping. """ pass diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index edd734842..10860fccb 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -7,11 +7,11 @@ import functools import multiprocessing as mp -import time from abc import ABC, abstractmethod + from datetime import timedelta -from typing import Any, Callable, List, Optional +from typing import Callable, List, Optional import torch import torch.distributed as dist @@ -21,7 +21,7 @@ from torchdata._constants import default_timeout_in_s from torchdata.dataloader2 import communication from torchdata.dataloader2.graph import DataPipe -from torchdata.datapipes.iter import FullSync, IterableWrapper +from torchdata.datapipes.iter import FullSync, IterableWrapper, IterDataPipe class ReadingServiceInterface(ABC): @@ -104,27 +104,46 @@ def _collate_no_op(batch): return batch[0] -class _IterateQueueDataPipes: +class _IterateQueueDataPipes(IterDataPipe): def __init__(self, datapipes): + # TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper + # into one class, which supports any number of queues. 
self.datapipes = datapipes + for dp in self.datapipes: + if not isinstance(dp, communication.iter.QueueWrapper): + raise Exception("Source datapipes should be an instance of iter.QueueWrapper") def __iter__(self): - # TODO(612): This is slow as it does not sends data requests ahead. - exclude_datapipes: List[Any] = [] - while len(exclude_datapipes) < len(self.datapipes): - for dp in self.datapipes: - if dp not in exclude_datapipes: - forever = True - while forever: - try: - value = dp.nonblocking_next() - yield value - forever = False - except StopIteration: - exclude_datapipes.append(dp) - forever = False - except communication.iter.NotAvailable: - time.sleep(0.001) + total_pipes = len(self.datapipes) + disabled_pipe = [False] * len(self.datapipes) + cnt_disabled_pipes = 0 + + for idx in range(total_pipes): + self.datapipes[idx].protocol.request_next() + + while cnt_disabled_pipes < total_pipes: + for idx in range(total_pipes): + if not disabled_pipe[idx]: + response = self.datapipes[idx].protocol.get_response_next(block=True) + if isinstance(response, communication.messages.StopIterationResponse): + disabled_pipe[idx] = True + cnt_disabled_pipes += 1 + continue + if isinstance(response, communication.messages.InvalidStateResponse): + raise communication.iter.InvalidStateResetRequired + if isinstance(response, communication.messages.TerminateResponse): + raise communication.iter.TerminateRequired + self.datapipes[idx].protocol.request_next() + yield response.value + + def reset(self): + # Collect the results of all outstanding requests to clear the queues + for dp in self.datapipes: + if dp.protocol.waiting_for_response(): + dp.protocol.get_response_next(block=True) + # NonBlocking DataPipes do not reset automatically, so reset them manually + for dp in self.datapipes: + dp.reset_iterator() class PrototypeMultiProcessingReadingService(ReadingServiceInterface): @@ -168,11 +187,10 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: ) self.datapipes.append(local_datapipe) - 
return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False) # type: ignore[return-value] + return _IterateQueueDataPipes(self.datapipes) # type: ignore[return-value] def initialize_iteration(self) -> None: - for dp in self.datapipes: - dp.reset_iterator() + pass def __del__(self): self.finalize()