Update Prefetcher and Implement PinMemory IterDataPipe (#1014)
Summary:
Fixes #1013

## Changes

- Simplify the control flow of the prefetcher
  - Delay any Exception raised in the thread worker and re-raise it in the main thread's `__iter__`
  - Stop prefetching whenever an Exception is received
  - Continue yielding data from `__iter__` as long as `stop_iteration` is not turned on or `buffer` is not empty (see the sketch below)
  - Add serialization test
- Add `PinMemory` DataPipe
  - `is_replicable() -> False` to keep it in the main process
  - Add unit tests
- Rename `test_proto_multi_rs.py` to `test_mprs.py`
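
A minimal sketch of the `__iter__` control flow described above, with illustrative attribute names (`buffer`, `stop_iteration`, `exception`) rather than the exact internals of torchdata's `Prefetcher`:

```python
import time
from collections import deque

class _PrefetcherSketch:
    """Sketch only: a worker thread fills `buffer`, sets `stop_iteration`
    when the source is exhausted, and records any error in `exception`
    instead of raising it inside the thread."""

    def __init__(self):
        self.buffer = deque()
        self.stop_iteration = False
        self.exception = None

    def __iter__(self):
        # Yield while the worker is still running or buffered data remains.
        while not self.stop_iteration or len(self.buffer) > 0:
            if self.exception is not None:
                # Delayed re-raise of the worker's error in the main thread.
                raise self.exception
            if self.buffer:
                yield self.buffer.popleft()
            else:
                time.sleep(0.001)  # wait for the worker to refill the buffer
```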

Pull Request resolved: #1014

Reviewed By: NivekT

Differential Revision: D43329696

Pulled By: ejguan

fbshipit-source-id: da4326dbe2388f4e23b9a1a3a5c43da09d29185a
ejguan committed Feb 21, 2023
1 parent dc72842 commit 6c4e23d
Showing 10 changed files with 235 additions and 46 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -42,6 +42,7 @@ Features described in this documentation are classified by release status:
dataloader2.rst
reading_service.rst


.. toctree::
:maxdepth: 2
:caption: Tutorial and Examples:
11 changes: 11 additions & 0 deletions docs/source/torchdata.datapipes.utils.rst
@@ -15,6 +15,17 @@ DataPipe Graph Visualization

to_graph

Common Utility Functions
--------------------------------------
.. currentmodule:: torchdata.datapipes.utils

.. autosummary::
:nosignatures:
:toctree: generated/
:template: function.rst

pin_memory_fn


File Object and Stream Utility
-------------------------------------
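A hedged usage sketch for `pin_memory_fn`, assuming its device argument can be left at its default; page pinning requires a CUDA-enabled environment:

```python
import torch
from torchdata.datapipes.utils import pin_memory_fn

# pin_memory_fn pins tensors recursively inside common containers
# (tensors, lists, dicts), as exercised by the tests in this commit.
batch = {"x": torch.arange(4), "y": [torch.zeros(2), torch.ones(2)]}
pinned = pin_memory_fn(batch)
assert pinned["x"].is_pinned()
assert all(t.is_pinned() for t in pinned["y"])
```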
test/{test_proto_multi_rs.py → test_mprs.py}
@@ -10,7 +10,7 @@
from unittest import TestCase

from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, PrototypeMultiProcessingReadingService
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper


@@ -29,9 +29,9 @@ def _add_one(x: int) -> int:
dp_parametrize = parametrize("dp", test_dps)


class TestPrototypeMultiProcessingReadingService(TestCase):
class TestMultiProcessingReadingService(TestCase):
r"""
This tests specific functionalities of PrototypeMultiProcessingReadingService, notably
This tests specific functionalities of MultiProcessingReadingService, notably
`pause`, `resume`, `snapshot`.
"""

@@ -40,7 +40,7 @@ def test_reading_service_pause_resume_0_worker(self, ctx) -> None:

# Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
# with `num_workers = 0`
rs0 = PrototypeMultiProcessingReadingService(
rs0 = MultiProcessingReadingService(
num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx
)
dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0)
@@ -64,7 +64,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_

# Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline
# properly pauses and resumes
rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=n_workers,
worker_prefetch_cnt=worker_prefetch_cnt,
main_prefetch_cnt=main_prefetch_cnt,
@@ -93,7 +93,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_
def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

# Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=n_workers,
worker_prefetch_cnt=worker_prefetch_cnt,
main_prefetch_cnt=main_prefetch_cnt,
@@ -117,7 +117,7 @@ def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefe
@parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)])
def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt
)

@@ -209,10 +209,10 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
# those DPs belong to a dispatching process and only do pause if worker_id == 0
# There might still be a race condition, need to look into the messages

# rs1 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
# rs2 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
# rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
# rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
# rs1 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
# rs2 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
# rs3 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
# rs4 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
# rss = [rs1, rs2, rs3, rs4]

# for n, rs in enumerate(rss):
@@ -284,7 +284,7 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
# pass


instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
instantiate_parametrized_tests(TestMultiProcessingReadingService)


if __name__ == "__main__":
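For context, a hedged usage sketch of the renamed service with `DataLoader2`, using only names that appear in this diff; the pipeline itself is illustrative:

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

# Build a small pipeline and read it through the multiprocessing service.
dp = IterableWrapper(range(10)).shuffle().sharding_filter()
rs = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
dl = DataLoader2(dp, reading_service=rs)
for x in dl:
    print(x)
dl.shutdown()
```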
44 changes: 43 additions & 1 deletion test/test_iterdatapipe.py
@@ -14,7 +14,7 @@
from typing import Dict

import expecttest
import torch.utils.data.datapipes.iter
import torch

import torchdata

@@ -42,6 +42,8 @@
)
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper

skipIfNoCUDA = unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")


def test_torchdata_pytorch_consistency() -> None:
def extract_datapipe_names(module):
@@ -68,6 +70,14 @@ def extract_datapipe_names(module):
raise AssertionError(msg + "\n".join(sorted(missing_datapipes)))


def _convert_to_tensor(data):
if isinstance(data, dict):
return {k: _convert_to_tensor(v) for k, v in data.items()}
elif isinstance(data, list):
return [_convert_to_tensor(v) for v in data]
return torch.tensor(data)


class TestIterDataPipe(expecttest.TestCase):
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10))
@@ -1475,6 +1485,38 @@ def test_random_splitter_iterdatapipe(self):
next(it_train)
next(it_valid) # No error, can keep going

@skipIfNoCUDA
def test_pin_memory(self):
# Tensor
dp = IterableWrapper([(i, i + 1) for i in range(10)]).map(_convert_to_tensor).pin_memory()
self.assertTrue(all(d.is_pinned() for d in dp))

# List of Tensors
dp = IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).pin_memory()
self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for d0, d1 in dp))

# Dict of Tensors
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory()
self.assertTrue(all(v.is_pinned() for d in dp for v in d.values()))

# Dict of List of Tensors
dp = (
IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)])
.map(_convert_to_tensor)
.pin_memory()
)
self.assertTrue(all(v.is_pinned() for d in dp for batch in d.values() for v in batch))

# List of Dict of Tensors
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
self.assertTrue(all(v.is_pinned() for batch in dp for d in batch for v in d.values()))

# List of List of Tensors
dp = (
IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
)
self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for batch in dp for d0, d1 in batch))


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions test/test_serialization.py
@@ -92,6 +92,10 @@ def _filter_by_module_availability(datapipes):
return [dp for dp in datapipes if dp[0] not in filter_set]


def _convert_to_tensor(data):
return torch.tensor(data)


class TestIterDataPipeSerialization(expecttest.TestCase):
def setUp(self):
self.temp_dir = create_temp_dir()
@@ -272,6 +276,7 @@ def test_serializable(self):
(),
{},
),
(iterdp.Prefetcher, None, (), {}),
(iterdp.ParquetDataFrameLoader, None, (), {"dtype": DTYPE}),
(iterdp.RarArchiveLoader, None, (), {}),
(
6 changes: 5 additions & 1 deletion torchdata/datapipes/iter/__init__.py
@@ -108,7 +108,10 @@
CSVParserIterDataPipe as CSVParser,
LineReaderIterDataPipe as LineReader,
)
from torchdata.datapipes.iter.util.prefetcher import PrefetcherIterDataPipe as Prefetcher
from torchdata.datapipes.iter.util.prefetcher import (
PinMemoryIterDataPipe as PinMemory,
PrefetcherIterDataPipe as Prefetcher,
)
from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
@@ -187,6 +190,7 @@
"OnlineReader",
"ParagraphAggregator",
"ParquetDataFrameLoader",
"PinMemory",
"Prefetcher",
"RandomSplitter",
"RarArchiveLoader",
1 change: 1 addition & 0 deletions torchdata/datapipes/iter/__init__.pyi.in
@@ -10,6 +10,7 @@ ${init_base}
from .util.decompressor import CompressionType
from torchdata._constants import default_timeout_in_s
from torchdata.datapipes.map import MapDataPipe
from torchdata.datapipes.utils import pin_memory_fn
from torch.utils.data import DataChunk, IterableDataset, default_collate
from torch.utils.data.datapipes._typing import _DataPipeMeta
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
