[DataLoader2] Implementing single iterator constraint
ghstack-source-id: b84c6c7ae0be647e53315fffc86964ae1f2eba52
Pull Request resolved: #700
NivekT committed Aug 1, 2022
1 parent a7745b9 commit fb1f63d
Showing 3 changed files with 109 additions and 34 deletions.
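In short, after this change each call to iter() on a DataLoader2 hands back a fresh iterator and invalidates any previously created one. The sketch below is distilled from the `test_dataloader2_reset` test added in this diff (it is not part of the commit itself) and shows the intended behavior:

    from torchdata.dataloader2 import DataLoader2
    from torchdata.datapipes.iter import IterableWrapper

    dl: DataLoader2 = DataLoader2(datapipe=IterableWrapper(range(10)))

    it1 = iter(dl)
    assert next(it1) == 0
    assert next(it1) == 1

    it2 = iter(dl)          # creating a second iterator invalidates it1
    assert next(it2) == 0   # iteration restarts from the beginning

    next(it1)  # raises RuntimeError: "...iterator has been invalidated..."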
72 changes: 63 additions & 9 deletions test/test_dataloader2.py
@@ -9,6 +9,8 @@
 import unittest
 from unittest import TestCase
 
+from torch.utils.data.graph import DataPipe
+
 from torchdata.dataloader2 import (
     DataLoader2,
     MultiProcessingReadingService,
@@ -19,15 +21,31 @@
 from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
 
 
+class _ReadingServiceWrapper:
+    def __init__(self, dp):
+        self.dp = dp
+
+    def __iter__(self):
+        self.it = iter(self.dp)
+        return self
+
+    def __next__(self):
+        return next(self.it)
+
+    @staticmethod
+    def return_one():
+        return 1
+
+
 class TestReadingService(ReadingServiceInterface):
-    def initialize(self, dp: IterDataPipe) -> IterDataPipe:
-        return dp
+    def initialize(self, dp: DataPipe) -> DataPipe:
+        return _ReadingServiceWrapper(dp)  # type: ignore[return-value]
 
 
 class DataLoader2Test(TestCase):
     def test_dataloader2(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
-        data_loader = DataLoader2(datapipe=test_data_pipe)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
 
         expected_batch = 0
         for batch in iter(data_loader):
@@ -36,12 +54,12 @@ def test_dataloader2(self) -> None:
 
     def test_dataloader2_shutdown(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
-        data_loader = DataLoader2(datapipe=test_data_pipe)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
         data_loader.shutdown()
 
     def test_dataloader2_state_dict(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
-        data_loader = DataLoader2(datapipe=test_data_pipe)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
 
         state = data_loader.state_dict()
         self.assertIsNotNone(state)
@@ -52,7 +70,7 @@ def test_dataloader2_state_dict(self) -> None:
     def test_dataloader2_reading_service(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
         reading_service = TestReadingService()
-        data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
 
         expected_batch = 0
         for batch in iter(data_loader):
@@ -62,7 +80,7 @@ def test_dataloader2_reading_service(self) -> None:
     def test_dataloader2_multi_process_reading_service(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
         reading_service = MultiProcessingReadingService()
-        data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
 
         expected_batch = 0
         for batch in iter(data_loader):
@@ -72,7 +90,7 @@ def test_dataloader2_multi_process_reading_service(self) -> None:
     def test_dataloader2_load_state_dict(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
         reading_service = TestReadingService()
-        data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
 
         batch = next(iter(data_loader))
         self.assertEqual(batch, 0)
@@ -84,7 +102,7 @@ def test_dataloader2_load_state_dict(self) -> None:
         data_loader.shutdown()
 
         test_data_pipe_2 = IterableWrapper(range(5))
-        restored_data_loader = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
+        restored_data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
         restored_data_loader.load_state_dict(state)
 
         restored_data_loader_datapipe = restored_data_loader.datapipe
@@ -104,6 +122,42 @@ def test_dataloader2_load_state_dict(self) -> None:
 
         restored_data_loader.shutdown()
 
+    def test_dataloader2_reset(self) -> None:
+
+        test_data_pipe = IterableWrapper(range(10))
+        reading_services = [None, TestReadingService(), MultiProcessingReadingService(num_workers=1)]
+
+        for reading_service in reading_services:
+            data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+
+            # Functional Test: Ensure that multiple sequential reads of the DL2 are possible
+            self.assertEqual(list(range(10)), list(data_loader))
+            self.assertEqual(list(range(10)), list(data_loader))
+            self.assertEqual(list(range(10)), list(data_loader))
+
+            # Functional Test: Ensure that the creation of a new iterator invalidates the old one
+            it1 = iter(data_loader)
+            self.assertEqual(0, next(it1))
+            self.assertEqual(1, next(it1))
+            it2 = iter(data_loader)
+            self.assertEqual(0, next(it2))
+            self.assertEqual(1, next(it2))
+            with self.assertRaisesRegex(RuntimeError, "iterator has been invalidated"):
+                next(it1)
+            self.assertEqual(list(range(2, 10)), list(it2))
+
+    def test_dataloader2_delegate_attribute(self) -> None:
+        test_data_pipe = IterableWrapper(range(10))
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=TestReadingService())
+
+        # Functional Test: Ensure that multiple sequential reads of the DL2 are possible
+        self.assertEqual(list(range(10)), list(data_loader))
+        self.assertEqual(list(range(10)), list(data_loader))
+
+        # Functional Test: Ensure that attributes/methods of `dataloader._datapipe_iter` can be used
+        it = iter(data_loader)
+        self.assertEqual(1, it.return_one())  # type: ignore[attr-defined]
+
 
 class DataLoader2ConsistencyTest(TestCase):
     r"""
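The `test_dataloader2_delegate_attribute` test above exercises attribute delegation: `return_one()` is defined on `_ReadingServiceWrapper`, yet it is reachable through the object returned by `iter(data_loader)`. A standalone sketch of that `__getattr__` forwarding pattern, with illustrative names that are not part of this commit:

    class _Inner:
        @staticmethod
        def return_one():
            return 1

    class _Facade:
        def __init__(self, inner):
            self._inner = inner

        def __getattr__(self, name):
            # Invoked only when normal attribute lookup fails;
            # forward the lookup to the wrapped object.
            if self._inner is None:
                raise AttributeError(name)
            return getattr(self._inner, name)

    assert _Facade(_Inner()).return_one() == 1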
66 changes: 43 additions & 23 deletions torchdata/dataloader2/dataloader2.py
@@ -43,6 +43,42 @@ class ConcurrencySpec:
     persistent_workers: bool = False
 
 
+class _DataLoader2Iterator(Iterator):
+    def __init__(self, dataloader: "DataLoader2", iterator_id: int):
+        self.dataloader = dataloader
+        self.iterator_id = iterator_id
+
+    def __next__(self) -> T_co:
+        if self.iterator_id == self.dataloader.valid_iterator_id:
+            self.dataloader._reset_iter = True
+            try:
+                return next(self.dataloader._datapipe_iter)  # type: ignore[arg-type]
+            except PauseIteration:
+                raise StopIteration
+            except StopIteration:
+                if self.dataloader.reading_service is not None:
+                    self.dataloader.reading_service.finalize_iteration()
+                raise
+        else:
+            if self.dataloader.reading_service is not None:
+                self.dataloader.reading_service.finalize_iteration()
+            raise RuntimeError(
+                "This iterator has been invalidated because another iterator has been created "
+                "from the same DataLoader2.\n"
+                "This may be caused by multiple references to the same DataLoader2. "
+                "For feedback regarding this single-iterator-per-DataLoader2 constraint, feel free "
+                "to comment on this issue: https://github.com/pytorch/data/issues/45."
+            )
+
+    def __getattr__(self, name):
+        """
+        Delegate attribute lookups to ``dataloader._datapipe_iter``.
+        """
+        if self.dataloader._datapipe_iter is None:
+            raise AttributeError
+        return getattr(self.dataloader._datapipe_iter, name)
+
+
 class DataLoader2(Generic[T_co]):
     def __init__(
         self,
@@ -53,7 +89,7 @@ def __init__(
         self.datapipe = datapipe
         self._adapted: bool = False
         self._datapipe_iter: Optional[Iterator[T_co]] = None
-        self._reset_iter: bool = True
+        self._reset_iter: bool = True  # Set to `False` when `__iter__` starts, and to `True` on `StopIteration`
         # TODO(630): Some ReadingServices might want to validate adapters, we can add this feature
         if datapipe_adapter_fn is None:
             self.datapipe_adapter_fns = None
@@ -62,8 +98,9 @@
         else:
             self.datapipe_adapter_fns = [datapipe_adapter_fn]
         self.reading_service = reading_service
-        self.reading_service_state: Optional[bytes] = None
+        self.reading_service_state: Optional[bytes] = None  # Not `None` once `load_state_dict` has been called
         self._terminated: bool = False
+        self.valid_iterator_id: Optional[int] = None
 
         if self.datapipe_adapter_fns is not None:
             for adapter_fn in self.datapipe_adapter_fns:
@@ -88,28 +125,10 @@ def __iter__(self) -> Iterator[T_co]:
             self.reading_service.initialize_iteration()
 
         self._datapipe_iter = iter(self.datapipe)
-
         self._reset_iter = False
 
-        return self
-
-    def __next__(self) -> T_co:
-        if self._reset_iter:
-            raise StopIteration
-        try:
-            return next(self._datapipe_iter)  # type: ignore[arg-type]
-        except PauseIteration:
-            raise StopIteration
-        except StopIteration:
-            if self.reading_service is not None:
-                self.reading_service.finalize_iteration()
-            self._reset_iter = True
-            raise
-
-    def __getattr__(self, name: str) -> Any:
-        if self._datapipe_iter is None:
-            raise AttributeError
-        return getattr(self._datapipe_iter, name)
+        self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1
+        return _DataLoader2Iterator(self, self.valid_iterator_id)
 
     def __del__(self) -> None:
         self.shutdown()
@@ -175,7 +194,8 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
         # iterator has already been created: 1) iterator is just created 2) iterator is created and iter is exhausted
         if self._datapipe_iter is not None:
            raise RuntimeError(
-                "DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. Please create a new dataloader in order to use load state dict."
+                "DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. "
+                "Please create a new dataloader in order to use load state dict."
            )
 
         serialized_datapipe = state[SERIALIZED_DATAPIPE_KEY_NAME]
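The invalidation logic above reduces to a generation counter: `__iter__` bumps `valid_iterator_id` and stamps the new iterator with it, and `_DataLoader2Iterator.__next__` refuses to run once its stamp goes stale. A distilled, self-contained sketch of that scheme (hypothetical class names, not torchdata code):

    class _Loader:
        def __init__(self, source):
            self._source = source
            self._source_iter = None
            self.valid_iterator_id = None  # bumped on every __iter__

        def __iter__(self):
            self._source_iter = iter(self._source)
            self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1
            return _LoaderIterator(self, self.valid_iterator_id)

    class _LoaderIterator:
        def __init__(self, loader, iterator_id):
            self._loader = loader
            self._iterator_id = iterator_id

        def __next__(self):
            # A stale stamp means another iterator was created since ours.
            if self._iterator_id != self._loader.valid_iterator_id:
                raise RuntimeError("This iterator has been invalidated")
            return next(self._loader._source_iter)

    loader = _Loader(range(3))
    it1 = iter(loader)
    assert next(it1) == 0
    it2 = iter(loader)     # bumps valid_iterator_id; it1 is now stale
    assert next(it2) == 0
    # next(it1) would now raise RuntimeError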
5 changes: 3 additions & 2 deletions torchdata/dataloader2/reading_service.py
@@ -146,10 +146,11 @@ def init_datapipe_process(num_workers, worker_id, datapipe):
 
     def initialize(self, datapipe: DataPipe) -> DataPipe:
         if self.num_workers == 0:
-            # TODO(616): Warn and recommend usage of InPorcessReadingService
+            # TODO(616): Warn and recommend usage of InProcessReadingService
             return datapipe
         for worker_id in range(self.num_workers):
-            # TODO(617): Separate into function, because we also need to apply distributed seed and call it inside process
+            # TODO(617): Separate into function, because we also need to apply distributed seed
+            # and call it inside process
             call_inside_process = functools.partial(self.init_datapipe_process, self.num_workers, worker_id)
             ctx = mp.get_context(self.multiprocessing_context)
             (process, req_queue, res_queue) = communication.eventloop.SpawnProcessForDataPipeline(
