Skip to content

Commit

Permalink
Change iterator over multiple Queue wrappers to request all protocols…
Browse files Browse the repository at this point in the history
… simulteniously

ghstack-source-id: ed8aae4f86aaaa4157d8803e127ee2151c658b30
Pull Request resolved: #769
  • Loading branch information
VitalyFedyunin committed Sep 9, 2022
1 parent 86df1a0 commit 9e11fd4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
12 changes: 11 additions & 1 deletion torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,17 @@ class NotAvailable(Exception):
class InvalidStateResetRequired(Exception):
"""
Returned by DataPipe when it is expecting to get reset request,
for example RouterDataPipe expecting all workers to request reset'
for example RouterDataPipe expecting all workers to request reset.
"""

pass


class TerminateRequired(Exception):
"""
Returned by DataPipe when it is expecting to get terminate request,
for example it got terminate request from other source and at the process
of stopping.
"""

pass
Expand Down
56 changes: 36 additions & 20 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

import functools
import multiprocessing as mp
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional
from typing import Callable, List, Optional

import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -102,24 +101,42 @@ def _collate_no_op(batch):
class _IterateQueueDataPipes:
def __init__(self, datapipes):
self.datapipes = datapipes
for dp in self.datapipes:
if not isinstance(dp, communication.iter.QueueWrapper):
raise Exception("Source datapipes should be an instance of iter.QueueWrapper")

def __iter__(self):
# TODO(612): This is slow as it does not sends data requests ahead.
exclude_datapipes: List[Any] = []
while len(exclude_datapipes) < len(self.datapipes):
for dp in self.datapipes:
if dp not in exclude_datapipes:
forever = True
while forever:
try:
value = dp.nonblocking_next()
yield value
forever = False
except StopIteration:
exclude_datapipes.append(dp)
forever = False
except communication.iter.NotAvailable:
time.sleep(0.001)
self.reset()
total_pipes = len(self.datapipes)
disabled_pipe = [False] * len(self.datapipes)
cnt_disabled_pipes = 0

for idx in range(total_pipes):
self.datapipes[idx].protocol.request_next()

while cnt_disabled_pipes < total_pipes:
for idx in range(total_pipes):
if not disabled_pipe[idx]:
response = self.datapipes[idx].protocol.get_response_next(block=True)
if isinstance(response, communication.messages.StopIterationResponse):
disabled_pipe[idx] = True
cnt_disabled_pipes += 1
break
if isinstance(response, communication.messages.InvalidStateResponse):
raise communication.iter.InvalidStateResetRequired
if isinstance(response, communication.messages.TerminateResponse):
raise communication.iter.TerminateRequired
self.datapipes[idx].protocol.request_next()
yield response.value

def reset(self):
# Collect all existing requests results to clear queues
for dp in self.datapipes:
if dp.protocol.waiting_for_response():
dp.protocol.get_response_next(block=True)
# NonBlocking DataPipes do not reset automatically, have to do it manually
for dp in self.datapipes:
dp.reset_iterator()


class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
Expand Down Expand Up @@ -166,8 +183,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False) # type: ignore[return-value]

def initialize_iteration(self) -> None:
for dp in self.datapipes:
dp.reset_iterator()
pass

def __del__(self):
self.finalize()
Expand Down

0 comments on commit 9e11fd4

Please sign in to comment.