From 6f618c7b113045377b588f4499ba07615796c0df Mon Sep 17 00:00:00 2001
From: erjia
Date: Fri, 17 Feb 2023 13:17:13 -0800
Subject: [PATCH] Update Prefetcher and Implement PinMemory IterDataPipe (#1014)

Summary:
Fixes https://github.com/pytorch/data/issues/1013

## Changes
- Simplify the control flow of the prefetcher
  - Delay Exceptions raised in the thread worker and re-raise them from `__iter__` in the main thread
  - Stop prefetching whenever an Exception is received
  - Continue yielding data from `__iter__` as long as `stop_iteration` is not turned on or the `buffer` is not empty
  - Add serialization test
- Add `PinMemory` DataPipe
  - `is_replicable() -> False` to keep it in the main process
  - Add unit tests
- Update `test_proto_multi_rs.py` to `test_mprs.py`

Pull Request resolved: https://github.com/pytorch/data/pull/1014

Reviewed By: NivekT

Differential Revision: D43329696

Pulled By: ejguan

fbshipit-source-id: da4326dbe2388f4e23b9a1a3a5c43da09d29185a
---
 docs/source/index.rst                         |   1 +
 docs/source/torchdata.datapipes.utils.rst     |  11 ++
 .../{test_proto_multi_rs.py => test_mprs.py}  |  24 +--
 test/test_iterdatapipe.py                     |  44 +++++-
 test/test_serialization.py                    |   5 +
 torchdata/datapipes/iter/__init__.py          |   6 +-
 torchdata/datapipes/iter/__init__.pyi.in      |   1 +
 torchdata/datapipes/iter/util/prefetcher.py   | 147 ++++++++++++++----
 torchdata/datapipes/utils/__init__.py         |   7 +-
 torchdata/datapipes/utils/pin_memory.py       |  35 +++++
 10 files changed, 235 insertions(+), 46 deletions(-)
 rename test/dataloader2/{test_proto_multi_rs.py => test_mprs.py} (91%)
 create mode 100644 torchdata/datapipes/utils/pin_memory.py

diff --git a/docs/source/index.rst b/docs/source/index.rst
index d61524cb8..5aa5895af 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,6 +42,7 @@ Features described in this documentation are classified by release status:
    dataloader2.rst
    reading_service.rst
 
+
 .. toctree::
    :maxdepth: 2
    :caption: Tutorial and Examples:
diff --git a/docs/source/torchdata.datapipes.utils.rst b/docs/source/torchdata.datapipes.utils.rst
index b8a6175e7..ad96e2430 100644
--- a/docs/source/torchdata.datapipes.utils.rst
+++ b/docs/source/torchdata.datapipes.utils.rst
@@ -15,6 +15,17 @@ DataPipe Graph Visualization
 
     to_graph
 
+Common Utility Functions
+--------------------------------------
+.. currentmodule:: torchdata.datapipes.utils
+
+.. 
autosummary:: + :nosignatures: + :toctree: generated/ + :template: function.rst + + pin_memory_fn + File Object and Stream Utility ------------------------------------- diff --git a/test/dataloader2/test_proto_multi_rs.py b/test/dataloader2/test_mprs.py similarity index 91% rename from test/dataloader2/test_proto_multi_rs.py rename to test/dataloader2/test_mprs.py index 52c5c239e..634094e49 100644 --- a/test/dataloader2/test_proto_multi_rs.py +++ b/test/dataloader2/test_mprs.py @@ -10,7 +10,7 @@ from unittest import TestCase from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize -from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, PrototypeMultiProcessingReadingService +from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService from torchdata.datapipes.iter import IterableWrapper @@ -29,9 +29,9 @@ def _add_one(x: int) -> int: dp_parametrize = parametrize("dp", test_dps) -class TestPrototypeMultiProcessingReadingService(TestCase): +class TestMultiProcessingReadingService(TestCase): r""" - This tests specific functionalities of PrototypeMultiProcessingReadingService, notably + This tests specific functionalities of MultiProcessingReadingService, notably `pause`, `resume`, `snapshot`. """ @@ -40,7 +40,7 @@ def test_reading_service_pause_resume_0_worker(self, ctx) -> None: # Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used # with `num_workers = 0` - rs0 = PrototypeMultiProcessingReadingService( + rs0 = MultiProcessingReadingService( num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx ) dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0) @@ -64,7 +64,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_ # Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline # properly pauses and resumes - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt, @@ -93,7 +93,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_ def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None: # Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt, @@ -117,7 +117,7 @@ def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefe @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)]) def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None: - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt ) @@ -209,10 +209,10 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr # those DPs belong to a dispatching process and only do pause if worker_id == 0 # There might still be a race condition, need to look into the messages - # rs1 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0) - # rs2 = 
PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2) - # rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0) - # rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2) + # rs1 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0) + # rs2 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2) + # rs3 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0) + # rs4 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2) # rss = [rs1, rs2, rs3, rs4] # for n, rs in enumerate(rss): @@ -284,7 +284,7 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr # pass -instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService) +instantiate_parametrized_tests(TestMultiProcessingReadingService) if __name__ == "__main__": diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 91719c9fe..c591b924c 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -14,7 +14,7 @@ from typing import Dict import expecttest -import torch.utils.data.datapipes.iter +import torch import torchdata @@ -42,6 +42,8 @@ ) from torchdata.datapipes.map import MapDataPipe, SequenceWrapper +skipIfNoCUDA = unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available") + def test_torchdata_pytorch_consistency() -> None: def extract_datapipe_names(module): @@ -68,6 +70,14 @@ def extract_datapipe_names(module): raise AssertionError(msg + "\n".join(sorted(missing_datapipes))) +def _convert_to_tensor(data): + if isinstance(data, dict): + return {k: _convert_to_tensor(v) for k, v in data.items()} + elif isinstance(data, list): + return [_convert_to_tensor(v) for v in data] + return torch.tensor(data) + + class TestIterDataPipe(expecttest.TestCase): def test_in_memory_cache_holder_iterdatapipe(self) -> None: source_dp = IterableWrapper(range(10)) @@ -1475,6 +1485,38 @@ def test_random_splitter_iterdatapipe(self): next(it_train) next(it_valid) # No error, can keep going + @skipIfNoCUDA + def test_pin_memory(self): + # Tensor + dp = IterableWrapper([(i, i + 1) for i in range(10)]).map(_convert_to_tensor).pin_memory() + self.assertTrue(all(d.is_pinned() for d in dp)) + + # List of Tensors + dp = IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).pin_memory() + self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for d0, d1 in dp)) + + # Dict of Tensors + dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory() + self.assertTrue(all(v.is_pinned() for d in dp for v in d.values())) + + # Dict of List of Tensors + dp = ( + IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)]) + .map(_convert_to_tensor) + .pin_memory() + ) + self.assertTrue(all(v.is_pinned() for d in dp for batch in d.values() for v in batch)) + + # List of Dict of Tensors + dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory() + self.assertTrue(all(v.is_pinned() for batch in dp for d in batch for v in d.values())) + + # List of List of Tensors + dp = ( + IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory() + ) + self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for batch in dp for d0, d1 in batch)) + 
if __name__ == "__main__": unittest.main() diff --git a/test/test_serialization.py b/test/test_serialization.py index 10051f4f0..a837a3948 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -92,6 +92,10 @@ def _filter_by_module_availability(datapipes): return [dp for dp in datapipes if dp[0] not in filter_set] +def _convert_to_tensor(data): + return torch.tensor(data) + + class TestIterDataPipeSerialization(expecttest.TestCase): def setUp(self): self.temp_dir = create_temp_dir() @@ -272,6 +276,7 @@ def test_serializable(self): (), {}, ), + (iterdp.Prefetcher, None, (), {}), (iterdp.ParquetDataFrameLoader, None, (), {"dtype": DTYPE}), (iterdp.RarArchiveLoader, None, (), {}), ( diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 06ecb9cda..61c3fe25d 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -108,7 +108,10 @@ CSVParserIterDataPipe as CSVParser, LineReaderIterDataPipe as LineReader, ) -from torchdata.datapipes.iter.util.prefetcher import PrefetcherIterDataPipe as Prefetcher +from torchdata.datapipes.iter.util.prefetcher import ( + PinMemoryIterDataPipe as PinMemory, + PrefetcherIterDataPipe as Prefetcher, +) from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar @@ -187,6 +190,7 @@ "OnlineReader", "ParagraphAggregator", "ParquetDataFrameLoader", + "PinMemory", "Prefetcher", "RandomSplitter", "RarArchiveLoader", diff --git a/torchdata/datapipes/iter/__init__.pyi.in b/torchdata/datapipes/iter/__init__.pyi.in index 141306b4f..79f906a01 100644 --- a/torchdata/datapipes/iter/__init__.pyi.in +++ b/torchdata/datapipes/iter/__init__.pyi.in @@ -10,6 +10,7 @@ ${init_base} from .util.decompressor import CompressionType from torchdata._constants import default_timeout_in_s from torchdata.datapipes.map import MapDataPipe +from torchdata.datapipes.utils import pin_memory_fn from torch.utils.data import DataChunk, IterableDataset, default_collate from torch.utils.data.datapipes._typing import _DataPipeMeta from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py index 6c72cee7a..7fb02b1ac 100644 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ b/torchdata/datapipes/iter/util/prefetcher.py @@ -10,8 +10,11 @@ from collections import deque from typing import Deque, Optional +import torch + from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe +from torchdata.datapipes.utils import pin_memory_fn PRODUCER_SLEEP_INTERVAL = 0.0001 # Interval between buffer fulfillment checks CONSUMER_SLEEP_INTERVAL = 0.0001 # Interval between checking items availability in buffer @@ -19,16 +22,16 @@ class _PrefetchData: def __init__(self, source_datapipe, buffer_size: int): - self.run_prefetcher = True + self.run_prefetcher: bool = True self.prefetch_buffer: Deque = deque() self.buffer_size: int = buffer_size self.source_datapipe = source_datapipe - self.stop_iteration = False + self.stop_iteration: bool = False @functional_datapipe("prefetch") class PrefetcherIterDataPipe(IterDataPipe): - """ + r""" Prefetches elements from the source DataPipe and puts them into a buffer (functional name: 
``prefetch``). Prefetching performs the operations (e.g. I/O, computations) of the DataPipes up to this one ahead of time and stores the result in the buffer, ready to be consumed by the subsequent DataPipe. It has no effect aside @@ -59,54 +62,45 @@ def __init__(self, source_datapipe, buffer_size: int = 10): @staticmethod def thread_worker(prefetch_data: _PrefetchData): - # Lazily import to prevent circular import - from torchdata.dataloader2 import communication - itr = iter(prefetch_data.source_datapipe) while not prefetch_data.stop_iteration: + # Run if not paused while prefetch_data.run_prefetcher: - if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size and not prefetch_data.stop_iteration: + if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size: try: item = next(itr) prefetch_data.prefetch_buffer.append(item) - except StopIteration: - prefetch_data.stop_iteration = True - except communication.iter.InvalidStateResetRequired: - prefetch_data.stop_iteration = True - except communication.iter.TerminateRequired: + except Exception as e: prefetch_data.run_prefetcher = False prefetch_data.stop_iteration = True - except Exception as e: prefetch_data.prefetch_buffer.append(e) - break - elif prefetch_data.stop_iteration and len(prefetch_data.prefetch_buffer) == 0: - prefetch_data.run_prefetcher = False else: # Buffer is full, waiting for main thread to consume items # TODO: Calculate sleep interval based on previous consumption speed time.sleep(PRODUCER_SLEEP_INTERVAL) - time.sleep(PRODUCER_SLEEP_INTERVAL) + # Sleep longer when this prefetcher thread is paused + time.sleep(PRODUCER_SLEEP_INTERVAL * 10) def __iter__(self): try: prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size) self.prefetch_data = prefetch_data thread = threading.Thread(target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True) + thread.start() self.thread = thread - self.thread.start() - while prefetch_data.run_prefetcher: + # Lazily import to prevent circular import + from torchdata.dataloader2 import communication + + while not prefetch_data.stop_iteration or len(prefetch_data.prefetch_buffer) > 0: if len(prefetch_data.prefetch_buffer) > 0: - item = prefetch_data.prefetch_buffer.popleft() - if isinstance(item, Exception): - prefetch_data.run_prefetcher = False - raise item - yield item + data = prefetch_data.prefetch_buffer.popleft() + if isinstance(data, Exception): + if isinstance(data, (StopIteration, communication.iter.TerminateRequired)): + break + raise data + yield data else: - # TODO: Calculate sleep interval based on previous availability speed - if not prefetch_data.stop_iteration: - time.sleep(CONSUMER_SLEEP_INTERVAL) - else: - prefetch_data.run_prefetcher = False + time.sleep(CONSUMER_SLEEP_INTERVAL) finally: prefetch_data.run_prefetcher = False prefetch_data.stop_iteration = True @@ -138,10 +132,105 @@ def reset(self): def pause(self): if self.thread is not None: + assert self.prefetch_data is not None self.prefetch_data.run_prefetcher = False def resume(self): if self.thread is not None and ( not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0 ): + assert self.prefetch_data is not None self.prefetch_data.run_prefetcher = True + + +@functional_datapipe("pin_memory") +class PinMemoryIterDataPipe(PrefetcherIterDataPipe): + r""" + Prefetches one element from the source DataPipe and moves it to pinned memory (functional name: ``pin_memory``). 
+ When used with ``MultiProcessingReadingService``, this DataPipe would be kept in the main process to prevent + duplicated CUDA context creation. + + Args: + source_datapipe: IterDataPipe from which samples are moved to pinned memory. + device: The device to pin samples. + pin_memory_fn: Optional callable function to move data to pinned memory. + A ``pin_memory_fn`` to handle general objects is provided by default. + + Example: + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(file_paths).open_files().readlines().map(tokenize_fn).pin_memory() + """ + + def __init__(self, source_datapipe, device=None, pin_memory_fn=pin_memory_fn): + if not torch.cuda.is_available(): + raise RuntimeError("``pin_memory`` can only be used when CUDA is available.") + # TODO: Add support for dynamic buffer based on the available size of pinned memory + super().__init__(source_datapipe, buffer_size=2) + if device is None: + device = torch.cuda.current_device() + self.device = device + self.pin_memory_fn = pin_memory_fn + + def is_replicable(self) -> bool: + return False + + @staticmethod + def thread_worker(prefetch_data: _PrefetchData, pin_memory_fn, device): # type: ignore[override] + itr = iter(prefetch_data.source_datapipe) + while not prefetch_data.stop_iteration: + # Run if not paused + while prefetch_data.run_prefetcher: + if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size: + try: + item = pin_memory_fn(next(itr), device) + prefetch_data.prefetch_buffer.append(item) + except Exception as e: + prefetch_data.run_prefetcher = False + prefetch_data.stop_iteration = True + prefetch_data.prefetch_buffer.append(e) + else: # Buffer is full, waiting for main thread to consume items + # TODO: Calculate sleep interval based on previous consumption speed + time.sleep(PRODUCER_SLEEP_INTERVAL) + # Sleep longer when this prefetcher thread is paused + time.sleep(PRODUCER_SLEEP_INTERVAL * 10) + + def __iter__(self): + try: + prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size) + self.prefetch_data = prefetch_data + thread = threading.Thread( + target=PinMemoryIterDataPipe.thread_worker, + args=(prefetch_data, self.pin_memory_fn, self.device), + daemon=True, + ) + thread.start() + self.thread = thread + + # Lazily import to prevent circular import + from torchdata.dataloader2 import communication + + while not prefetch_data.stop_iteration or len(prefetch_data.prefetch_buffer) > 0: + if len(prefetch_data.prefetch_buffer) > 0: + data = prefetch_data.prefetch_buffer.popleft() + if isinstance(data, Exception): + if isinstance(data, (StopIteration, communication.iter.TerminateRequired)): + break + raise data + yield data + else: + time.sleep(CONSUMER_SLEEP_INTERVAL) + finally: + prefetch_data.run_prefetcher = False + prefetch_data.stop_iteration = True + thread.join() + + def __getstate__(self): + state = super().__getstate__() + state["pin_memory_fn"] = self.pin_memory_fn + state["device"] = self.device + return state + + def __setstate__(self, state): + super().__setstate__(state) + self.pin_memory_fn = state["pin_memory_fn"] + self.device = state["device"] diff --git a/torchdata/datapipes/utils/__init__.py b/torchdata/datapipes/utils/__init__.py index c74e8f702..2776af79f 100644 --- a/torchdata/datapipes/utils/__init__.py +++ b/torchdata/datapipes/utils/__init__.py @@ -6,7 +6,8 @@ from torch.utils.data.datapipes.utils.common import StreamWrapper -from ._visualization import to_graph -from .janitor import janitor +from torchdata.datapipes.utils._visualization 
import to_graph
+from torchdata.datapipes.utils.janitor import janitor
+from torchdata.datapipes.utils.pin_memory import pin_memory_fn
 
-__all__ = ["StreamWrapper", "janitor", "to_graph"]
+__all__ = ["StreamWrapper", "janitor", "pin_memory_fn", "to_graph"]
diff --git a/torchdata/datapipes/utils/pin_memory.py b/torchdata/datapipes/utils/pin_memory.py
new file mode 100644
index 000000000..9803faec9
--- /dev/null
+++ b/torchdata/datapipes/utils/pin_memory.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import collections
+
+
+def pin_memory_fn(data, device=None):
+    r"""
+    Utility function to move data to pinned memory. If special treatment is needed to move
+    the input data to pinned memory, please attach a ``pin_memory`` method to the expected
+    data class.
+    """
+    if hasattr(data, "pin_memory"):  # Including torch.Tensor
+        return data.pin_memory(device)
+    elif isinstance(data, (str, bytes)):
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        pinned_data = {k: pin_memory_fn(sample, device) for k, sample in data.items()}
+        try:
+            return type(data)(**pinned_data)
+        except TypeError:
+            # The mapping type may not support `__init__(iterable)`.
+            return pinned_data
+    elif isinstance(data, collections.abc.Sequence):
+        pinned_data = [pin_memory_fn(sample, device) for sample in data]  # type: ignore[assignment]
+        try:
+            return type(data)(*pinned_data)
+        except TypeError:
+            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
+            return pinned_data
+    else:
+        return data
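
Below is a minimal usage sketch of the DataPipes touched by this patch, based only on the APIs shown above (`prefetch`, the new `pin_memory`, and `pin_memory_fn`). The sample data, the lambda, and the `buffer_size` value are illustrative, not part of the patch.

```python
# Illustrative sketch; assumes torchdata with this patch applied.
import torch
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.utils import pin_memory_fn

# Prefetch items in a background thread; an Exception raised by the worker is
# buffered and re-raised from the main thread's __iter__.
dp = IterableWrapper(range(100)).map(lambda x: torch.tensor([x, x + 1])).prefetch(buffer_size=10)

if torch.cuda.is_available():
    # pin_memory is non-replicable, so it stays in the main process when used
    # with MultiProcessingReadingService.
    dp = dp.pin_memory()
    assert all(t.is_pinned() for t in dp)

    # pin_memory_fn recurses through Mapping/Sequence containers.
    pinned = pin_memory_fn({"x": torch.arange(4), "y": [torch.zeros(2), torch.ones(2)]})
    assert pinned["x"].is_pinned() and all(t.is_pinned() for t in pinned["y"])
```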