Update Prefetcher and Implement PinMemory IterDataPipe #1014

Closed · wants to merge 8 commits
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -42,6 +42,7 @@ Features described in this documentation are classified by release status:
dataloader2.rst
reading_service.rst


.. toctree::
:maxdepth: 2
:caption: Tutorial and Examples:
11 changes: 11 additions & 0 deletions docs/source/torchdata.datapipes.utils.rst
@@ -15,6 +15,17 @@ DataPipe Graph Visualization

to_graph

Common Utility Functions
--------------------------------------
.. currentmodule:: torchdata.datapipes.utils

.. autosummary::
:nosignatures:
:toctree: generated/
:template: function.rst

pin_memory_fn


File Object and Stream Utility
-------------------------------------
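Note: the `pin_memory_fn` documented above is exposed from `torchdata.datapipes.utils`. A minimal usage sketch, assuming it accepts arbitrarily nested dict/list containers of tensors as the tests in this PR exercise (pinning requires a CUDA-enabled build):

```python
import torch
from torchdata.datapipes.utils import pin_memory_fn

# Pin every tensor nested inside a container; non-tensor leaves pass through.
batch = {"x": torch.zeros(2, 3), "y": [torch.ones(4), torch.ones(4)]}
pinned = pin_memory_fn(batch)
assert pinned["x"].is_pinned() and all(t.is_pinned() for t in pinned["y"])
```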
@@ -10,7 +10,7 @@
from unittest import TestCase

from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
-from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, PrototypeMultiProcessingReadingService
+from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper


@@ -29,9 +29,9 @@ def _add_one(x: int) -> int:
dp_parametrize = parametrize("dp", test_dps)


-class TestPrototypeMultiProcessingReadingService(TestCase):
+class TestMultiProcessingReadingService(TestCase):
r"""
-    This tests specific functionalities of PrototypeMultiProcessingReadingService, notably
+    This tests specific functionalities of MultiProcessingReadingService, notably
`pause`, `resume`, `snapshot`.
"""

@@ -40,7 +40,7 @@ def test_reading_service_pause_resume_0_worker(self, ctx) -> None:

# Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
# with `num_workers = 0`
-        rs0 = PrototypeMultiProcessingReadingService(
+        rs0 = MultiProcessingReadingService(
num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx
)
dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0)
@@ -64,7 +64,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_

# Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline
# properly pauses and resumes
-        rs = PrototypeMultiProcessingReadingService(
+        rs = MultiProcessingReadingService(
num_workers=n_workers,
worker_prefetch_cnt=worker_prefetch_cnt,
main_prefetch_cnt=main_prefetch_cnt,
@@ -93,7 +93,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_
def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

# Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
-        rs = PrototypeMultiProcessingReadingService(
+        rs = MultiProcessingReadingService(
num_workers=n_workers,
worker_prefetch_cnt=worker_prefetch_cnt,
main_prefetch_cnt=main_prefetch_cnt,
@@ -117,7 +117,7 @@ def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefe
@parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)])
def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

-        rs = PrototypeMultiProcessingReadingService(
+        rs = MultiProcessingReadingService(
num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt
)

@@ -209,10 +209,10 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
# those DPs belong to a dispatching process and only do pause if worker_id == 0
# There might still be a race condition, need to look into the messages

-    # rs1 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
-    # rs2 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
-    # rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
-    # rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
+    # rs1 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
+    # rs2 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
+    # rs3 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
+    # rs4 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
# rss = [rs1, rs2, rs3, rs4]

# for n, rs in enumerate(rss):
@@ -284,7 +284,7 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
# pass


-instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
+instantiate_parametrized_tests(TestMultiProcessingReadingService)


if __name__ == "__main__":
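The changes above track the rename of `PrototypeMultiProcessingReadingService` to `MultiProcessingReadingService`. A minimal sketch of the construction pattern these tests rely on (argument values are illustrative):

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10))
rs = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
dl = DataLoader2(dp, reading_service=rs)
for x in dl:  # items are produced by 2 worker processes
    pass
dl.shutdown()
```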
44 changes: 43 additions & 1 deletion test/test_iterdatapipe.py
@@ -14,7 +14,7 @@
from typing import Dict

import expecttest
-import torch.utils.data.datapipes.iter
+import torch

import torchdata

@@ -42,6 +42,8 @@
)
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper

skipIfNoCUDA = unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")


def test_torchdata_pytorch_consistency() -> None:
def extract_datapipe_names(module):
@@ -68,6 +70,14 @@ def extract_datapipe_names(module):
raise AssertionError(msg + "\n".join(sorted(missing_datapipes)))


def _convert_to_tensor(data):
if isinstance(data, dict):
return {k: _convert_to_tensor(v) for k, v in data.items()}
elif isinstance(data, list):
return [_convert_to_tensor(v) for v in data]
return torch.tensor(data)


class TestIterDataPipe(expecttest.TestCase):
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10))
@@ -1475,6 +1485,38 @@ def test_random_splitter_iterdatapipe(self):
next(it_train)
next(it_valid) # No error, can keep going

@skipIfNoCUDA
def test_pin_memory(self):
# Tensor
dp = IterableWrapper([(i, i + 1) for i in range(10)]).map(_convert_to_tensor).pin_memory()
self.assertTrue(all(d.is_pinned() for d in dp))

# List of Tensors
dp = IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).pin_memory()
self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for d0, d1 in dp))

# Dict of Tensors
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory()
self.assertTrue(all(v.is_pinned() for d in dp for v in d.values()))

# Dict of List of Tensors
dp = (
IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)])
.map(_convert_to_tensor)
.pin_memory()
)
self.assertTrue(all(v.is_pinned() for d in dp for batch in d.values() for v in batch))

# List of Dict of Tensors
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
self.assertTrue(all(v.is_pinned() for batch in dp for d in batch for v in d.values()))

# List of List of Tensors
dp = (
IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
)
self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for batch in dp for d0, d1 in batch))


if __name__ == "__main__":
unittest.main()
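For reference, the nested-container behavior `test_pin_memory` checks could be implemented roughly as below. This is a hedged sketch of the recursion, not the actual `pin_memory_fn` body:

```python
import torch

def _pin_recursive(data):
    # Sketch: pin tensors found in (possibly nested) dicts, lists, and tuples.
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    if isinstance(data, dict):
        return {k: _pin_recursive(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(_pin_recursive(v) for v in data)
    return data
```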
5 changes: 5 additions & 0 deletions test/test_serialization.py
@@ -92,6 +92,10 @@ def _filter_by_module_availability(datapipes):
return [dp for dp in datapipes if dp[0] not in filter_set]


def _convert_to_tensor(data):
return torch.tensor(data)


class TestIterDataPipeSerialization(expecttest.TestCase):
def setUp(self):
self.temp_dir = create_temp_dir()
@@ -272,6 +276,7 @@ def test_serializable(self):
(),
{},
),
(iterdp.Prefetcher, None, (), {}),
(iterdp.ParquetDataFrameLoader, None, (), {"dtype": DTYPE}),
(iterdp.RarArchiveLoader, None, (), {}),
(
6 changes: 5 additions & 1 deletion torchdata/datapipes/iter/__init__.py
@@ -108,7 +108,10 @@
CSVParserIterDataPipe as CSVParser,
LineReaderIterDataPipe as LineReader,
)
-from torchdata.datapipes.iter.util.prefetcher import PrefetcherIterDataPipe as Prefetcher
+from torchdata.datapipes.iter.util.prefetcher import (
+    PinMemoryIterDataPipe as PinMemory,
+    PrefetcherIterDataPipe as Prefetcher,
+)
from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
@@ -187,6 +190,7 @@
"OnlineReader",
"ParagraphAggregator",
"ParquetDataFrameLoader",
"PinMemory",
"Prefetcher",
"RandomSplitter",
"RarArchiveLoader",
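With `PinMemory` exported here, the datapipe is reachable through the `.pin_memory()` functional form used by the tests above. An illustrative pipeline (a CUDA build is required for pinning to take effect):

```python
import torch
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper([torch.tensor([float(i)]) for i in range(8)]).pin_memory()
for t in dp:
    # Pinned host memory enables asynchronous host-to-device copies.
    x = t.to("cuda", non_blocking=True)
```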
1 change: 1 addition & 0 deletions torchdata/datapipes/iter/__init__.pyi.in
@@ -10,6 +10,7 @@ ${init_base}
from .util.decompressor import CompressionType
from torchdata._constants import default_timeout_in_s
from torchdata.datapipes.map import MapDataPipe
from torchdata.datapipes.utils import pin_memory_fn
from torch.utils.data import DataChunk, IterableDataset, default_collate
from torch.utils.data.datapipes._typing import _DataPipeMeta
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES