Skip to content

Commit

Permalink
Change iterator over multiple Queue wrappers to request all protocols…
Browse files Browse the repository at this point in the history
… simultaneously

ghstack-source-id: 09787247b78b4054d16f606070fc00880a0763c9
Pull Request resolved: #769
  • Loading branch information
VitalyFedyunin committed Sep 9, 2022
1 parent 86df1a0 commit 344872d
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

import functools
import multiprocessing as mp
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional
from typing import Callable, List, Optional

import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -102,24 +101,42 @@ def _collate_no_op(batch):
# Iterable that pulls values from several QueueWrapper datapipes and yields
# them as a single stream.
# NOTE(review): this text is a rendered commit diff whose +/- markers and
# indentation were stripped, so pre-change and post-change lines of __iter__
# are interleaved below and the block is not runnable as shown.
class _IterateQueueDataPipes:
# Store the source datapipes and reject any that is not a
# communication.iter.QueueWrapper.
def __init__(self, datapipes):
self.datapipes = datapipes
for dp in self.datapipes:
if not isinstance(dp, communication.iter.QueueWrapper):
raise Exception("Source datapipes should be an instance of iter.QueueWrapper")

# Yield values from all source datapipes until every one of them is exhausted.
def __iter__(self):
# NOTE(review): the lines from this TODO down to time.sleep(0.001) look like
# the pre-change implementation (they rely on `time` and `Any`, whose imports
# are dropped in the import hunk of this diff) — confirm against the repository.
# It polls each not-yet-excluded datapipe with nonblocking_next(), retries
# after a 1 ms sleep on NotAvailable, and excludes a datapipe on StopIteration.
# TODO(612): This is slow as it does not send data requests ahead.
exclude_datapipes: List[Any] = []
while len(exclude_datapipes) < len(self.datapipes):
for dp in self.datapipes:
if dp not in exclude_datapipes:
forever = True
while forever:
try:
value = dp.nonblocking_next()
yield value
forever = False
except StopIteration:
exclude_datapipes.append(dp)
forever = False
except communication.iter.NotAvailable:
time.sleep(0.001)
# NOTE(review): from here on the lines look like the post-change
# implementation — confirm against the repository.
# Clear pending responses and reset the source iterators before starting.
self.reset()
total_pipes = len(self.datapipes)
# disabled_pipe[idx] becomes True once datapipe idx has answered with
# StopIterationResponse; cnt_disabled_pipes counts those entries.
disabled_pipe = [False] * len(self.datapipes)
cnt_disabled_pipes = 0

# Send a next-item request to every datapipe before reading any response.
for idx in range(total_pipes):
self.datapipes[idx].protocol.request_next()

# Keep scanning the datapipes in index order until all of them are disabled.
while cnt_disabled_pipes < total_pipes:
for idx in range(total_pipes):
if not disabled_pipe[idx]:
# Blocking read of the response to the request issued earlier.
response = self.datapipes[idx].protocol.get_response_next(block=True)
if isinstance(response, communication.messages.StopIterationResponse):
disabled_pipe[idx] = True
cnt_disabled_pipes += 1
# NOTE(review): with the indentation lost, this break presumably belongs
# to the inner for loop, which would restart the scan at index 0 instead
# of moving on to the next datapipe — confirm that is intended.
break
if isinstance(response, communication.messages.InvalidStateResponse):
raise communication.iter.InvalidStateResetRequired
if isinstance(response, communication.messages.TerminateResponse):
raise communication.iter.TerminateRequired
# Request the following item from this datapipe before yielding the current one.
self.datapipes[idx].protocol.request_next()
yield response.value

# Consume outstanding responses, then reset every source datapipe's iterator.
def reset(self):
# Collect the results of all outstanding requests to clear the queues
for dp in self.datapipes:
if dp.protocol.waiting_for_response():
dp.protocol.get_response_next(block=True)
# NonBlocking DataPipes do not reset automatically, have to do it manually
for dp in self.datapipes:
dp.reset_iterator()


class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
Expand Down Expand Up @@ -166,8 +183,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False) # type: ignore[return-value]

# Per-iteration hook of the reading service; takes no arguments and returns None.
# NOTE(review): in this rendered diff (markers and indentation stripped) the
# loop calling dp.reset_iterator() appears to be the removed body and `pass`
# its replacement; _IterateQueueDataPipes.reset() also calls reset_iterator()
# on each datapipe — confirm against the repository.
def initialize_iteration(self) -> None:
for dp in self.datapipes:
dp.reset_iterator()
pass

def __del__(self):
self.finalize()
Expand Down

0 comments on commit 344872d

Please sign in to comment.