-
Notifications
You must be signed in to change notification settings - Fork 221
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Custom DataLoader with two levels of subprocess workers (#343)
* Enable using a worker pool inside of the InputStrategy * A draft of a custom dataloader that allows two levels of process pools * Transfer the dataset object to a process only once * Add some description * Tests for LhotseDataLoader, rename InputStrategy -> BatchIO * Remove type annotation that causes a syntax error * Fix test * Require Python 3.7 or higher to use LhotseDataLoader * Disable LhotseDataLoader test for Python < 3.7
- Loading branch information
Showing
10 changed files
with
254 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import platform | ||
from collections import deque | ||
from concurrent.futures import ProcessPoolExecutor | ||
from multiprocessing import get_context | ||
from typing import Any, Dict, List | ||
|
||
import torch.utils.data | ||
|
||
from lhotse.dataset.sampling import CutSampler | ||
|
||
|
||
class LhotseDataLoader: | ||
""" | ||
A simplified ``DataLoader`` implementation that relies on a ``ProcessPoolExecutor``. | ||
The main difference between this and ``torch.utils.data.DataLoader`` is that | ||
:class:`.LhotseDataLoader` allows to launch subprocesses inside of its workers. | ||
This is useful for working with dataset classes which perform dynamic batching | ||
and need to perform concurrent I/O to read all the necessary data from disk/network. | ||
.. note:: :class:`.LhotseDataLoader` does not support ``num_workers=0``. | ||
.. warning:: :class:`.LhotseDataLoader` is experimental and not guaranteed to work | ||
correctly across all possible edge cases related to subprocess worker termination. | ||
If you experience stability problems, contact us or use a standard ``DataLoader`` | ||
instead. | ||
.. warning:: :class:`.LhotseDataLoader` requires Python >= 3.7. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dataset: torch.utils.data.Dataset, | ||
sampler: CutSampler, | ||
num_workers: int = 1, | ||
prefetch_factor: int = 2, | ||
) -> None: | ||
from packaging.version import parse as _version | ||
|
||
if _version(platform.python_version()) < _version("3.7"): | ||
raise RuntimeError("LhotseDataLoader requires Python version at least 3.7") | ||
assert num_workers >= 1 | ||
assert prefetch_factor >= 1 | ||
self.dataset = dataset | ||
self.sampler = sampler | ||
self.num_workers = num_workers | ||
self.prefetch_factor = prefetch_factor | ||
# Mutable state | ||
self._iter = None | ||
self._futures = deque([]) | ||
# Start the worker processes. The initializer receives the dataset object | ||
# from the main process and caches it globally, so that it can be re-used | ||
# for subsequent tasks sent to the worker. This helps avoid excessive | ||
# communication between the processes. | ||
self.pool = ProcessPoolExecutor( | ||
num_workers, | ||
initializer=_init_worker, | ||
initargs=(dataset,), | ||
mp_context=get_context("spawn"), | ||
) | ||
|
||
def __iter__(self) -> "LhotseDataLoader": | ||
"""Prepares the sampler for iteration and schedules initial tasks to the workers.""" | ||
self._iter = iter(self.sampler) | ||
for _ in range(self.prefetch_factor * self.num_workers): | ||
self._schedule_one() | ||
return self | ||
|
||
def _schedule_one(self) -> None: | ||
"""Submits a task and stores the future for results retrieval.""" | ||
if self._iter is not None: | ||
try: | ||
self._futures.append(self.pool.submit(_get_item, next(self._iter))) | ||
except StopIteration: | ||
self._iter = None | ||
|
||
def _retrieve_one(self) -> Dict[str, Any]: | ||
"""Retrieves the result from the earliest submitted task.""" | ||
if self._futures: | ||
return self._futures.popleft().result() | ||
raise StopIteration() | ||
|
||
def __next__(self) -> Dict[str, Any]: | ||
"""Submits a new batch to process and then retrieves and returns a completed batch.""" | ||
self._schedule_one() | ||
return self._retrieve_one() | ||
|
||
|
||
def _init_worker(dataset: torch.utils.data.Dataset) -> None: | ||
""" | ||
Stores the dataset in the global state of the process -- this is safe because | ||
the process is initialized only once and used for unique dataset in its life span. | ||
""" | ||
global _GLOBAL_DATASET_CACHE | ||
_GLOBAL_DATASET_CACHE = dataset | ||
|
||
|
||
def _get_item(cut_ids: List[str]) -> Dict[str, Any]: | ||
""" | ||
Queries the globally cached dataset to retrieve a batch. Has to be run | ||
inside a worker process that was initialized with :meth:`._init_worker`. | ||
""" | ||
return _GLOBAL_DATASET_CACHE[cut_ids] | ||
|
||
|
||
_GLOBAL_DATASET_CACHE = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.