Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change iterator over multiple Queue wrappers to request all protocols simultaneously #769

Closed
12 changes: 11 additions & 1 deletion torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,17 @@ class NotAvailable(Exception):
class InvalidStateResetRequired(Exception):
    """
    Returned by DataPipe when it is expecting to get reset request,
    for example RouterDataPipe expecting all workers to request reset.
    """

    pass
Comment on lines +39 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, it might be better to add an error message in __init__,
e.g. super().__init__(msg).



class TerminateRequired(Exception):
    """
    Raised by a DataPipe that expects a terminate request — for example, it has
    already received a terminate request from another source and is in the
    process of stopping.
    """
Expand Down
64 changes: 41 additions & 23 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

import functools
import multiprocessing as mp
import time

from abc import ABC, abstractmethod

from datetime import timedelta
from typing import Any, Callable, List, Optional
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
Expand All @@ -21,7 +21,7 @@
from torchdata._constants import default_timeout_in_s
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe
from torchdata.datapipes.iter import FullSync, IterableWrapper
from torchdata.datapipes.iter import FullSync, IterableWrapper, IterDataPipe


class ReadingServiceInterface(ABC):
Expand Down Expand Up @@ -104,27 +104,46 @@ def _collate_no_op(batch):
return batch[0]


class _IterateQueueDataPipes:
class _IterateQueueDataPipes(IterDataPipe):
def __init__(self, datapipes):
    """Wrap a list of ``QueueWrapper`` datapipes for combined iteration.

    Raises a plain ``Exception`` when any element is not a ``QueueWrapper``.
    """
    # TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper
    # into one class, which supports any number of queues.
    self.datapipes = datapipes
    if any(not isinstance(source_dp, communication.iter.QueueWrapper) for source_dp in self.datapipes):
        raise Exception("Source datapipes should be an instance of iter.QueueWrapper")

def __iter__(self):
    """Iterate over all source queues round-robin, keeping one request in
    flight per pipe so the worker processes produce data in parallel.

    Raises:
        communication.iter.InvalidStateResetRequired: a worker reported an
            invalid state and must be reset first.
        communication.iter.TerminateRequired: a worker is terminating.
    """
    total_pipes = len(self.datapipes)
    disabled_pipe = [False] * total_pipes
    cnt_disabled_pipes = 0

    # Prime every queue with an initial request so all workers start
    # producing simultaneously instead of being polled one at a time.
    for idx in range(total_pipes):
        self.datapipes[idx].protocol.request_next()

    while cnt_disabled_pipes < total_pipes:
        for idx in range(total_pipes):
            if not disabled_pipe[idx]:
                response = self.datapipes[idx].protocol.get_response_next(block=True)
                if isinstance(response, communication.messages.StopIterationResponse):
                    # This pipe is exhausted; stop asking it for more data.
                    disabled_pipe[idx] = True
                    cnt_disabled_pipes += 1
                    continue
                if isinstance(response, communication.messages.InvalidStateResponse):
                    raise communication.iter.InvalidStateResetRequired
                if isinstance(response, communication.messages.TerminateResponse):
                    raise communication.iter.TerminateRequired
                # Request the next item *before* yielding, so the worker
                # computes it while the consumer processes the current one.
                self.datapipes[idx].protocol.request_next()
                yield response.value

def reset(self):
    """Drain every queue of outstanding responses, then reset all source pipes."""
    # Collect all existing requests results to clear queues
    for source_dp in self.datapipes:
        protocol = source_dp.protocol
        if protocol.waiting_for_response():
            protocol.get_response_next(block=True)
    # NonBlocking DataPipes do not reset automatically, have to do it manually
    for source_dp in self.datapipes:
        source_dp.reset_iterator()


class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
Expand Down Expand Up @@ -168,11 +187,10 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
)
self.datapipes.append(local_datapipe)

return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False) # type: ignore[return-value]
return _IterateQueueDataPipes(self.datapipes) # type: ignore[return-value]

def initialize_iteration(self) -> None:
    # Intentionally a no-op: per-iteration request priming and reset handling
    # moved into _IterateQueueDataPipes, so no setup is needed here.
    pass

def __del__(self):
    # Delegate cleanup to finalize() when the reading service is garbage collected.
    self.finalize()
Expand Down