Change iterator over multiple Queue wrappers to request all protocols simultaneously #769
Changes from all commits
d56bd3e
c0338d3
79c19df
d27e3b9
861d9f0
753b62f
@@ -7,11 +7,11 @@
 import functools
 import multiprocessing as mp
 import time

 from abc import ABC, abstractmethod

 from datetime import timedelta
-from typing import Any, Callable, List, Optional
+from typing import Callable, List, Optional

 import torch
 import torch.distributed as dist
@@ -21,7 +21,7 @@
 from torchdata._constants import default_timeout_in_s
 from torchdata.dataloader2 import communication
 from torchdata.dataloader2.graph import DataPipe
-from torchdata.datapipes.iter import FullSync, IterableWrapper
+from torchdata.datapipes.iter import FullSync, IterableWrapper, IterDataPipe


 class ReadingServiceInterface(ABC):
@@ -104,27 +104,46 @@ def _collate_no_op(batch):
     return batch[0]


-class _IterateQueueDataPipes:
+class _IterateQueueDataPipes(IterDataPipe):
     def __init__(self, datapipes):
         # TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper
         # into one class, which supports any number of queues.
         self.datapipes = datapipes
         for dp in self.datapipes:
             if not isinstance(dp, communication.iter.QueueWrapper):
                 raise Exception("Source datapipes should be an instance of iter.QueueWrapper")

     def __iter__(self):
-        # TODO(612): This is slow as it does not sends data requests ahead.
-        exclude_datapipes: List[Any] = []
-        while len(exclude_datapipes) < len(self.datapipes):
-            for dp in self.datapipes:
-                if dp not in exclude_datapipes:
-                    forever = True
-                    while forever:
-                        try:
-                            value = dp.nonblocking_next()
-                            yield value
-                            forever = False
-                        except StopIteration:
-                            exclude_datapipes.append(dp)
-                            forever = False
-                        except communication.iter.NotAvailable:
-                            time.sleep(0.001)
+        total_pipes = len(self.datapipes)
+        disabled_pipe = [False] * len(self.datapipes)
+        cnt_disabled_pipes = 0
+
+        for idx in range(total_pipes):
+            self.datapipes[idx].protocol.request_next()
Comment on lines +121 to +122:

This is the main review comment, the rest are some nits. Might be a noob question. We now request and receive data from …

Reply: However, I'm still considering the possibility of merging _IterateQueueDataPipes and QueueWrapper to make it one class that supports 1:M queues.
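To make the 1:M idea in that reply concrete, here is a purely speculative sketch, not the torchdata implementation; the class name, the `is_stop` attribute, and the protocol methods are stand-ins invented for illustration.

```python
# Speculative sketch only: one consumer class accepting one or many
# request/response protocols, roughly what merging QueueWrapper and
# _IterateQueueDataPipes into a 1:M class could look like.
from typing import Any, Iterator, List


class MergedQueueConsumer:
    def __init__(self, protocols: List[Any]):
        self.protocols = protocols          # a single protocol or M of them

    def __iter__(self) -> Iterator[Any]:
        # Same request-ahead loop as in the diff; it degenerates to the old
        # single-queue QueueWrapper behaviour when len(self.protocols) == 1.
        for p in self.protocols:
            p.request_next()
        disabled = [False] * len(self.protocols)
        remaining = len(self.protocols)
        while remaining:
            for idx, p in enumerate(self.protocols):
                if disabled[idx]:
                    continue
                response = p.get_response_next(block=True)
                if getattr(response, "is_stop", False):   # stand-in for StopIterationResponse
                    disabled[idx] = True
                    remaining -= 1
                    continue
                p.request_next()
                yield response.value
```

The receive loop that follows in the diff is exactly the part such a merged class would reuse.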
+
+        while cnt_disabled_pipes < total_pipes:
+            for idx in range(total_pipes):
+                if not disabled_pipe[idx]:
+                    response = self.datapipes[idx].protocol.get_response_next(block=True)
+                    if isinstance(response, communication.messages.StopIterationResponse):
+                        disabled_pipe[idx] = True
+                        cnt_disabled_pipes += 1
+                        continue
+                    if isinstance(response, communication.messages.InvalidStateResponse):
+                        raise communication.iter.InvalidStateResetRequired
+                    if isinstance(response, communication.messages.TerminateResponse):
+                        raise communication.iter.TerminateRequired
Comment on lines +132 to +135:

Question: shouldn't these be caught by …

Reply: No, because I'm not using next of QueueWrapper, but instead accessing the protocol's next directly.
+                    self.datapipes[idx].protocol.request_next()
+                    yield response.value
+
+    def reset(self):
+        # Collect all existing requests results to clear queues
+        for dp in self.datapipes:
+            if dp.protocol.waiting_for_response():
+                dp.protocol.get_response_next(block=True)
+        # NonBlocking DataPipes do not reset automatically, have to do it manually
+        for dp in self.datapipes:
+            dp.reset_iterator()


 class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
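To see why the request-ahead loop helps, here is a minimal, runnable sketch of the same pattern against toy worker threads. Everything in it (ToyWorker, the ("item", value)/("stop", None) tuples, round_robin) is invented for the example and is not the torchdata API; the point is only that priming one request per worker, and re-requesting immediately after each response, keeps every worker producing while the consumer yields downstream.

```python
# Illustrative sketch, not the torchdata implementation: the same request-ahead
# pattern as the new __iter__, exercised against toy worker threads.
import queue
import threading
import time


class ToyWorker:
    """Serves 'next' requests for a small dataset on a background thread."""

    def __init__(self, items):
        self.requests: queue.Queue = queue.Queue()
        self.responses: queue.Queue = queue.Queue()
        self._items = iter(items)
        threading.Thread(target=self._serve, daemon=True).start()

    def _serve(self):
        while True:
            self.requests.get()                      # wait for a "next" request
            time.sleep(0.01)                         # pretend producing is expensive
            try:
                self.responses.put(("item", next(self._items)))
            except StopIteration:
                self.responses.put(("stop", None))
                return

    def request_next(self):
        self.requests.put("next")


def round_robin(workers):
    # Prime the pipeline: one outstanding request per worker, so all workers
    # start producing before the first item is consumed.
    for w in workers:
        w.request_next()

    disabled = [False] * len(workers)
    remaining = len(workers)
    while remaining:
        for idx, w in enumerate(workers):
            if disabled[idx]:
                continue
            kind, value = w.responses.get()          # blocking receive
            if kind == "stop":
                disabled[idx] = True
                remaining -= 1
                continue
            w.request_next()                         # re-request before yielding
            yield value


if __name__ == "__main__":
    workers = [ToyWorker(range(i * 10, i * 10 + 3)) for i in range(4)]
    print(list(round_robin(workers)))                # round-robins across workers
```

Because each worker always has a request in flight at the moment an item is yielded, the old per-item pattern of "ask, wait, yield" becomes "yield while the worker is already computing its next item", which is the slowness the removed TODO(612) comment was pointing at.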
@@ -168,11 +187,10 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
             )
             self.datapipes.append(local_datapipe)

-        return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False)  # type: ignore[return-value]
+        return _IterateQueueDataPipes(self.datapipes)  # type: ignore[return-value]

     def initialize_iteration(self) -> None:
-        for dp in self.datapipes:
-            dp.reset_iterator()
+        pass

     def __del__(self):
         self.finalize()
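One consequence of this hunk: the per-epoch reset work that initialize_iteration used to do now lives in _IterateQueueDataPipes.reset() above, presumably picked up through the datapipe's own reset mechanism now that the reading service returns an IterDataPipe directly. The sketch below uses invented toy names (it is not torchdata code) to show why that reset() drains one pending response per pipe first: after the request-ahead loop, a still-active pipe can have a request in flight, and an undrained response would surface as a stale element in the next epoch.

```python
# Toy sketch, invented names: why reset() drains pending responses before resetting.
import queue


class ToyPipe:
    def __init__(self):
        self.responses: queue.Queue = queue.Queue()
        self._in_flight = 0

    def request_next(self):
        self._in_flight += 1
        self.responses.put("payload")                # pretend the worker answered

    def waiting_for_response(self):
        return self._in_flight > 0

    def get_response_next(self, block=True):
        self._in_flight -= 1
        return self.responses.get(block)


def reset(pipes):
    # Drain any outstanding response so the queues are empty for the next epoch;
    # the real reset() then also calls reset_iterator() on every pipe.
    for p in pipes:
        if p.waiting_for_response():
            p.get_response_next(block=True)


pipes = [ToyPipe() for _ in range(3)]
for p in pipes:
    p.request_next()                                 # requests still in flight mid-epoch
reset(pipes)
assert all(p.responses.empty() for p in pipes)       # no stale items leak into epoch 2
```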
BTW, it might be better to add an error message in __init__: super().__init__(msg).
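A minimal sketch of the pattern this comment suggests, assuming it refers to the exception types that the diff raises without arguments; the class name and message below are hypothetical, not an actual torchdata class.

```python
# Hypothetical example of the suggested pattern.
class ResetRequiredError(Exception):
    def __init__(self, msg: str = "DataPipe is in an invalid state and must be reset"):
        # Forwarding a default message to Exception.__init__ makes bare
        # `raise ResetRequiredError` sites produce a readable str(exc) and traceback.
        super().__init__(msg)


try:
    raise ResetRequiredError()
except ResetRequiredError as exc:
    print(exc)   # DataPipe is in an invalid state and must be reset
```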