Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[DataLoader2] Implementing single iterator constraint #700

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions test/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import unittest
from unittest import TestCase

from torch.utils.data.graph import DataPipe

from torchdata.dataloader2 import (
DataLoader2,
MultiProcessingReadingService,
Expand All @@ -19,15 +21,31 @@
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


class _ReadingServiceWrapper:
def __init__(self, dp):
self.dp = dp

def __iter__(self):
self.it = iter(self.dp)
return self

def __next__(self):
return next(self.it)

@staticmethod
def return_one():
return 1
Comment on lines +24 to +37
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A method that dataloader._datapipe_iter will have



class TestReadingService(ReadingServiceInterface):
    """ReadingService stub for tests.

    Wraps the datapipe in ``_ReadingServiceWrapper`` (instead of returning it
    unchanged) so tests can both detect that initialization ran and exercise
    attribute delegation through the resulting iterator.
    """

    # NOTE: the scraped diff showed both the pre-change signature
    # (`dp: IterDataPipe`) and the post-change one; only the post-change
    # definition is kept here.
    def initialize(self, dp: DataPipe) -> DataPipe:
        return _ReadingServiceWrapper(dp)  # type: ignore[return-value]


class DataLoader2Test(TestCase):
def test_dataloader2(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader = DataLoader2(datapipe=test_data_pipe)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)

expected_batch = 0
for batch in iter(data_loader):
Expand All @@ -36,12 +54,12 @@ def test_dataloader2(self) -> None:

def test_dataloader2_shutdown(self) -> None:
    """DataLoader2 can be shut down cleanly without ever being iterated."""
    # NOTE: the scraped diff carried both the pre-change and post-change
    # constructor lines; only the post-change (annotated) line is kept.
    test_data_pipe = IterableWrapper(range(3))
    data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
    data_loader.shutdown()

def test_dataloader2_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader = DataLoader2(datapipe=test_data_pipe)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)

state = data_loader.state_dict()
self.assertIsNotNone(state)
Expand All @@ -52,7 +70,7 @@ def test_dataloader2_state_dict(self) -> None:
def test_dataloader2_reading_service(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = TestReadingService()
data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

expected_batch = 0
for batch in iter(data_loader):
Expand All @@ -62,7 +80,7 @@ def test_dataloader2_reading_service(self) -> None:
def test_dataloader2_multi_process_reading_service(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = MultiProcessingReadingService()
data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

expected_batch = 0
for batch in iter(data_loader):
Expand All @@ -72,7 +90,7 @@ def test_dataloader2_multi_process_reading_service(self) -> None:
def test_dataloader2_load_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = TestReadingService()
data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

batch = next(iter(data_loader))
self.assertEqual(batch, 0)
Expand All @@ -84,7 +102,7 @@ def test_dataloader2_load_state_dict(self) -> None:
data_loader.shutdown()

test_data_pipe_2 = IterableWrapper(range(5))
restored_data_loader = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
restored_data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
restored_data_loader.load_state_dict(state)

restored_data_loader_datapipe = restored_data_loader.datapipe
Expand All @@ -104,6 +122,42 @@ def test_dataloader2_load_state_dict(self) -> None:

restored_data_loader.shutdown()

def test_dataloader2_reset(self) -> None:
    """Repeated iteration of a DataLoader2 restarts from scratch, and a
    newly created iterator invalidates any outstanding one, for every
    kind of reading service (including none)."""
    source_dp = IterableWrapper(range(10))
    expected = list(range(10))

    for rs in (None, TestReadingService(), MultiProcessingReadingService(num_workers=1)):
        dl: DataLoader2 = DataLoader2(datapipe=source_dp, reading_service=rs)

        # Functional Test: Ensure multiple sequential reads of DL2 is possible
        for _ in range(3):
            self.assertEqual(expected, list(dl))

        # Functional Test: Ensure that the creation of a new iterator invalidates the old one
        first_it = iter(dl)
        self.assertEqual(0, next(first_it))
        self.assertEqual(1, next(first_it))
        second_it = iter(dl)
        self.assertEqual(0, next(second_it))
        self.assertEqual(1, next(second_it))
        with self.assertRaisesRegex(RuntimeError, "iterator has been invalidated"):
            next(first_it)
        self.assertEqual(expected[2:], list(second_it))

def test_dataloader2_delegate_attribute(self) -> None:
    """Attributes of the underlying datapipe iterator are reachable through
    the DataLoader2 iterator via delegation."""
    source_dp = IterableWrapper(range(10))
    dl: DataLoader2 = DataLoader2(datapipe=source_dp, reading_service=TestReadingService())

    # Functional Test: Ensure multiple sequential reads of DL2 is possible
    for _ in range(2):
        self.assertEqual(list(range(10)), list(dl))

    # Functional Test: Ensure that attribute/method of `dataloader._datapipe_iter` can be used
    dl_it = iter(dl)
    self.assertEqual(1, dl_it.return_one())  # type: ignore[attr-defined]


class DataLoader2ConsistencyTest(TestCase):
r"""
Expand Down
3 changes: 2 additions & 1 deletion torchdata/dataloader2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


from .dataloader2 import DataLoader2
from .dataloader2 import DataLoader2, DataLoader2Iterator
from .error import PauseIteration
from .reading_service import (
MultiProcessingReadingService,
Expand All @@ -16,6 +16,7 @@

__all__ = [
"DataLoader2",
"DataLoader2Iterator",
"MultiProcessingReadingService",
"PauseIteration",
"PrototypeMultiProcessingReadingService",
Expand Down
66 changes: 43 additions & 23 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,42 @@ class ConcurrencySpec:
persistent_workers: bool = False


class DataLoader2Iterator(Iterator):
    """Iterator handed out by ``DataLoader2.__iter__``.

    Each iterator records the ``iterator_id`` it was created with; creating a
    newer iterator bumps ``dataloader.valid_iterator_id``, which invalidates
    every older iterator (the single-iterator-per-DataLoader2 constraint).

    Fixes applied: removed review-UI residue lines that had been scraped into
    the class body, and corrected grammar in the user-facing invalidation
    error message ("caused multiple" -> "caused by multiple").
    """

    def __init__(self, dataloader: "DataLoader2", iterator_id: int):
        self.dataloader = dataloader
        self.iterator_id = iterator_id

    def __next__(self) -> T_co:
        if self.iterator_id == self.dataloader.valid_iterator_id:
            # NOTE(review): set before stepping the pipe; a comment elsewhere
            # says `_reset_iter` flags exhaustion — confirm intended placement.
            self.dataloader._reset_iter = True
            try:
                return next(self.dataloader._datapipe_iter)  # type: ignore[arg-type]
            except PauseIteration:
                # A paused reading service surfaces to callers as exhaustion.
                raise StopIteration
            except StopIteration:
                # Let the reading service clean up before propagating.
                if self.dataloader.reading_service is not None:
                    self.dataloader.reading_service.finalize_iteration()
                raise
        else:
            # Stale iterator: a newer one exists. Finalize and refuse to yield.
            if self.dataloader.reading_service is not None:
                self.dataloader.reading_service.finalize_iteration()
            raise RuntimeError(
                "This iterator has been invalidated because another iterator has been created "
                "from the same DataLoader2.\n"
                "This may be caused by multiple references to the same DataLoader2. "
                "For feedback regarding this single iterator per DataLoader2 constraint, feel free "
                "to comment on this issue: https://github.com/pytorch/data/issues/45."
            )

    def __getattr__(self, name):
        """
        To delegate operations to ``dataloader._datapipe_iter``.
        """
        if self.dataloader._datapipe_iter is None:
            raise AttributeError(name)
        return getattr(self.dataloader._datapipe_iter, name)


class DataLoader2(Generic[T_co]):
def __init__(
self,
Expand All @@ -53,7 +89,7 @@ def __init__(
self.datapipe = datapipe
self._adapted: bool = False
self._datapipe_iter: Optional[Iterator[T_co]] = None
self._reset_iter: bool = True
self._reset_iter: bool = True # Sets to `False` when __iter__ starts, and `True` when `StopIteration``
# TODO(630): Some ReadingServices might want to validate adapters, we can add this feature
if datapipe_adapter_fn is None:
self.datapipe_adapter_fns = None
Expand All @@ -62,8 +98,9 @@ def __init__(
else:
self.datapipe_adapter_fns = [datapipe_adapter_fn]
self.reading_service = reading_service
self.reading_service_state: Optional[bytes] = None
self.reading_service_state: Optional[bytes] = None # is not `None` when `load_state_dict` is called
self._terminated: bool = False
self.valid_iterator_id: Optional[int] = None

if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
Expand All @@ -88,28 +125,10 @@ def __iter__(self) -> Iterator[T_co]:
self.reading_service.initialize_iteration()

self._datapipe_iter = iter(self.datapipe)
ejguan marked this conversation as resolved.
Show resolved Hide resolved

self._reset_iter = False

return self

def __next__(self) -> T_co:
if self._reset_iter:
raise StopIteration
try:
return next(self._datapipe_iter) # type: ignore[arg-type]
except PauseIteration:
raise StopIteration
except StopIteration:
if self.reading_service is not None:
self.reading_service.finalize_iteration()
self._reset_iter = True
raise
NivekT marked this conversation as resolved.
Show resolved Hide resolved

def __getattr__(self, name: str) -> Any:
if self._datapipe_iter is None:
raise AttributeError
return getattr(self._datapipe_iter, name)
self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1
return DataLoader2Iterator(self, self.valid_iterator_id)

def __del__(self) -> None:
self.shutdown()
Expand Down Expand Up @@ -175,7 +194,8 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
# iterator has already been created: 1) iterator is just created 2) iterator is created and iter is exhausted
if self._datapipe_iter is not None:
raise RuntimeError(
"DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. Please create a new dataloader in order to use load state dict."
"DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. "
"Please create a new dataloader in order to use load state dict."
)

serialized_datapipe = state[SERIALIZED_DATAPIPE_KEY_NAME]
Expand Down
5 changes: 3 additions & 2 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,11 @@ def init_datapipe_process(num_workers, worker_id, datapipe):

def initialize(self, datapipe: DataPipe) -> DataPipe:
if self.num_workers == 0:
# TODO(616): Warn and recommend usage of InPorcessReadingService
# TODO(616): Warn and recommend usage of InProcessReadingService
return datapipe
for worker_id in range(self.num_workers):
# TODO(617): Separate into function, because we also need to apply distributed seed and call it inside process
# TODO(617): Separate into function, because we also need to apply distributed seed
# and call it inside process
call_inside_process = functools.partial(self.init_datapipe_process, self.num_workers, worker_id)
ctx = mp.get_context(self.multiprocessing_context)
(process, req_queue, res_queue) = communication.eventloop.SpawnProcessForDataPipeline(
Expand Down