[DataLoader2] Implementing single iterator constraint #700

Closed · wants to merge 5 commits
Changes from 2 commits
42 changes: 34 additions & 8 deletions test/test_dataloader2.py
@@ -9,6 +9,8 @@
import unittest
from unittest import TestCase

from torch.utils.data.graph import DataPipe

from torchdata.dataloader2 import (
DataLoader2,
MultiProcessingReadingService,
@@ -20,14 +22,14 @@


class TestReadingService(ReadingServiceInterface):
def initialize(self, dp: IterDataPipe) -> IterDataPipe:
def initialize(self, dp: DataPipe) -> DataPipe:
return dp


class DataLoader2Test(TestCase):
def test_dataloader2(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader = DataLoader2(datapipe=test_data_pipe)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)

expected_batch = 0
for batch in iter(data_loader):
@@ -36,12 +38,12 @@ def test_dataloader2(self) -> None:

def test_dataloader2_shutdown(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader = DataLoader2(datapipe=test_data_pipe)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
data_loader.shutdown()

def test_dataloader2_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader = DataLoader2(datapipe=test_data_pipe)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)

state = data_loader.state_dict()
self.assertIsNotNone(state)
@@ -52,7 +54,7 @@ def test_dataloader2_state_dict(self) -> None:
def test_dataloader2_reading_service(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = TestReadingService()
data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

expected_batch = 0
for batch in iter(data_loader):
@@ -62,7 +64,7 @@ def test_dataloader2_reading_service(self) -> None:
def test_dataloader2_multi_process_reading_service(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = MultiProcessingReadingService()
data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

expected_batch = 0
for batch in iter(data_loader):
@@ -72,7 +74,7 @@ def test_dataloader2_multi_process_reading_service(self) -> None:
def test_dataloader2_load_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = TestReadingService()
data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

batch = next(iter(data_loader))
self.assertEqual(batch, 0)
@@ -84,7 +86,7 @@ def test_dataloader2_load_state_dict(self) -> None:
data_loader.shutdown()

test_data_pipe_2 = IterableWrapper(range(5))
restored_data_loader = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
restored_data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
restored_data_loader.load_state_dict(state)

restored_data_loader_datapipe = restored_data_loader.datapipe
@@ -104,6 +106,30 @@ def test_dataloader2_load_state_dict(self) -> None:

restored_data_loader.shutdown()

def test_dataloader2_reset(self) -> None:

test_data_pipe = IterableWrapper(range(10))
reading_services = [None, TestReadingService(), MultiProcessingReadingService(num_workers=1)]

for reading_service in reading_services:
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

# Functional Test: Ensure multiple sequential reads of DL2 is possible
self.assertEqual(list(range(10)), list(data_loader))
self.assertEqual(list(range(10)), list(data_loader))
self.assertEqual(list(range(10)), list(data_loader))

# Functional Test: Ensure that the creation of a new iterator invalidates the old one
it1 = iter(data_loader)
self.assertEqual(0, next(it1))
self.assertEqual(1, next(it1))
it2 = iter(data_loader)
self.assertEqual(0, next(it2))
self.assertEqual(1, next(it2))
with self.assertRaisesRegex(RuntimeError, "iterator has been invalidated"):
next(it1)
self.assertEqual(list(range(2, 10)), list(it2))


class DataLoader2ConsistencyTest(TestCase):
r"""
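As a quick illustration of the behavior the new `test_dataloader2_reset` test asserts, here is a minimal usage sketch (assuming the public `DataLoader2` and `IterableWrapper` imports; this snippet is not part of the diff):

```python
from torchdata.dataloader2 import DataLoader2
from torchdata.datapipes.iter import IterableWrapper

dl = DataLoader2(datapipe=IterableWrapper(range(10)))

# Sequential full reads keep working: each call to iter() resets the loader.
assert list(dl) == list(range(10))
assert list(dl) == list(range(10))

# Creating a second iterator invalidates the first one.
it1 = iter(dl)
print(next(it1))  # 0
it2 = iter(dl)    # it1 becomes stale here
print(next(it2))  # 0
try:
    next(it1)
except RuntimeError as err:
    print(err)    # "...iterator has been invalidated..."
```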
63 changes: 45 additions & 18 deletions torchdata/dataloader2/dataloader2.py
@@ -43,6 +43,40 @@ class ConcurrencySpec:
persistent_workers: bool = False


class _DataLoader2Iterator(Iterator):
def __init__(self, dataloader: "DataLoader2", iterator_id: int):
self.dataloader = dataloader
self.iterator_id = iterator_id

def __next__(self) -> T_co:
if self.iterator_id == self.dataloader.valid_iterator_id:
self.dataloader._reset_iter = True
try:
return next(self.dataloader._datapipe_iter) # type: ignore[arg-type]
except PauseIteration:
raise StopIteration
except StopIteration:
if self.dataloader.reading_service is not None:
self.dataloader.reading_service.finalize_iteration()
raise
else:
raise RuntimeError(
"This iterator has been invalidated because another iterator has been created "
"from the same DataLoader2.\n"
"This may be caused multiple references to the same DataLoader2. "
"For feedback regarding this single iterator per DataLoader2 constraint, feel free "
"to comment on this issue: https://github.com/pytorch/data/issues/45."
)

def __getattr__(self, name):
"""
To delegate operations to ``dataloader._datapipe_iter``.
"""
if self.dataloader._datapipe_iter is None:
raise AttributeError
return getattr(self.dataloader._datapipe_iter, name)


class DataLoader2(Generic[T_co]):
def __init__(
self,
@@ -53,7 +87,7 @@ def __init__(
self.datapipe = datapipe
self._adapted: bool = False
self._datapipe_iter: Optional[Iterator[T_co]] = None
self._reset_iter: bool = True
self._reset_iter: bool = True  # Set to `False` when `__iter__` starts, and back to `True` on `StopIteration`
# TODO(630): Some ReadingServices might want to validate adapters, we can add this feature
if datapipe_adapter_fn is None:
self.datapipe_adapter_fns = None
@@ -62,8 +96,9 @@
else:
self.datapipe_adapter_fns = [datapipe_adapter_fn]
self.reading_service = reading_service
self.reading_service_state: Optional[bytes] = None
self.reading_service_state: Optional[bytes] = None  # Only populated when `load_state_dict` is called
self._terminated: bool = False
self.valid_iterator_id: Optional[int] = None

if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
@@ -88,25 +123,16 @@ def __iter__(self) -> Iterator[T_co]:
self.reading_service.initialize_iteration()

self._datapipe_iter = iter(self.datapipe)

self._reset_iter = False

return self

def __next__(self) -> T_co:
if self._reset_iter:
raise StopIteration
try:
return next(self._datapipe_iter) # type: ignore[arg-type]
except PauseIteration:
raise StopIteration
except StopIteration:
if self.reading_service is not None:
self.reading_service.finalize_iteration()
self._reset_iter = True
raise
self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1
return _DataLoader2Iterator(self, self.valid_iterator_id)

def __getattr__(self, name: str) -> Any:
"""
Delegate methods (e.g. `limit`, `pause`, `resume`, etc) to the iterator
created by the ReadingService and DataPipe.
"""
Review discussion on this method:

Contributor: If we directly delegate to datapipe_iter from dataloader_iter, we don't need this function.

@NivekT (author, Jul 29, 2022): I am fine with this change, but I would flag that methods such as state_dict and shutdown will not be delegated to DataLoader2. I am going to change the implementation to do that. As long as that is not an issue, we can delegate directly to dataloader_iter.

@NivekT (author, Jul 29, 2022): Also, do we want to allow users to call dl2.limit(), dl2.resume(), etc., or do we want them to always invoke those from the iterator? If we want the former, we will need to keep this method.

@ejguan (Jul 29, 2022): IIRC, users should only invoke such APIs on the iterator object. I am fine either way. cc: @Miiira

@NivekT (author): I think invoking from one single place (i.e. the iterator) rather than multiple places is better and less error-prone. The internal test is calling .resume on the iterator, so I think it is fine to remove the API from the DataLoader2 class.

Contributor: Yes, we only need limit and resume on the iterator. I'm good with removing this from DataLoader2. Maybe we also want to look at the Lightning DataModule train_dataloader return type with this change.

if self._datapipe_iter is None:
raise AttributeError
return getattr(self._datapipe_iter, name)
@@ -175,7 +201,8 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
# The iterator may already exist in two cases: 1) it was just created, or 2) it was created and has been exhausted
if self._datapipe_iter is not None:
raise RuntimeError(
"DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. Please create a new dataloader in order to use load state dict."
"DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. "
"Please create a new dataloader in order to use load state dict."
)

serialized_datapipe = state[SERIALIZED_DATAPIPE_KEY_NAME]
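To summarize the mechanism introduced above: `DataLoader2.__iter__` bumps `valid_iterator_id` each time it is called, every `_DataLoader2Iterator` remembers the id it was created with, `__next__` only proceeds while the two still match, and unknown attributes are delegated to the underlying datapipe iterator. Below is a compact, self-contained sketch of that technique (hypothetical class names, not the actual torchdata implementation):

```python
from collections.abc import Iterator
from typing import Any, Optional


class _StaleIteratorExample:
    """Illustrative stand-in for DataLoader2's single-iterator bookkeeping."""

    def __init__(self, source) -> None:
        self._source = source
        self._source_iter: Optional[Iterator] = None
        self.valid_iterator_id: Optional[int] = None

    def __iter__(self) -> "_ExampleIterator":
        self._source_iter = iter(self._source)
        # Bump the id; any previously handed-out iterator is now stale.
        self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1
        return _ExampleIterator(self, self.valid_iterator_id)


class _ExampleIterator(Iterator):
    def __init__(self, loader: _StaleIteratorExample, iterator_id: int) -> None:
        self.loader = loader
        self.iterator_id = iterator_id

    def __next__(self) -> Any:
        if self.iterator_id != self.loader.valid_iterator_id:
            raise RuntimeError("This iterator has been invalidated by a newer iterator.")
        return next(self.loader._source_iter)

    def __getattr__(self, name: str) -> Any:
        # Delegate unknown attributes to the wrapped iterator, mirroring the
        # delegation discussed in the review thread above.
        if self.loader._source_iter is None:
            raise AttributeError(name)
        return getattr(self.loader._source_iter, name)


loader = _StaleIteratorExample(range(5))
it1 = iter(loader)
assert next(it1) == 0
it2 = iter(loader)  # invalidates it1
assert next(it2) == 0
try:
    next(it1)
except RuntimeError as err:
    print(err)
```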
5 changes: 3 additions & 2 deletions torchdata/dataloader2/reading_service.py
@@ -146,10 +146,11 @@ def init_datapipe_process(num_workers, worker_id, datapipe):

def initialize(self, datapipe: DataPipe) -> DataPipe:
if self.num_workers == 0:
# TODO(616): Warn and recommend usage of InPorcessReadingService
# TODO(616): Warn and recommend usage of InProcessReadingService
return datapipe
for worker_id in range(self.num_workers):
# TODO(617): Separate into function, because we also need to apply distributed seed and call it inside process
# TODO(617): Separate into function, because we also need to apply distributed seed
# and call it inside process
call_inside_process = functools.partial(self.init_datapipe_process, self.num_workers, worker_id)
ctx = mp.get_context(self.multiprocessing_context)
(process, req_queue, res_queue) = communication.eventloop.SpawnProcessForDataPipeline(
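One note on the `functools.partial` call touched in the hunk above: it pre-binds `num_workers` and `worker_id` so that only the datapipe remains to be supplied inside the spawned worker process. A generic sketch of that pattern follows (the function body is a hypothetical placeholder, not the torchdata event loop):

```python
import functools


def init_datapipe_process(num_workers: int, worker_id: int, datapipe):
    # Placeholder for per-worker setup, e.g. sharding the datapipe by worker_id.
    print(f"worker {worker_id + 1}/{num_workers} received {datapipe!r}")
    return datapipe


num_workers = 2
calls = []
for worker_id in range(num_workers):
    # Bind everything except the datapipe; the worker supplies it later.
    calls.append(functools.partial(init_datapipe_process, num_workers, worker_id))

for call_inside_process in calls:
    call_inside_process(["example", "datapipe"])
```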