Custom DataLoader with two levels of subprocess workers (#343)
* Enable using a worker pool inside of the InputStrategy

* A draft of a custom dataloader that allows two levels of process pools

* Transfer the dataset object to a process only once

* Add some description

* Tests for LhotseDataLoader, rename InputStrategy -> BatchIO

* Remove type annotation that causes a syntax error

* Fix test

* Require Python 3.7 or higher to use LhotseDataLoader

* Disable LhotseDataLoader test for Python < 3.7
pzelasko authored Jul 23, 2021
1 parent 3c6c0ae commit e95d134
Showing 10 changed files with 254 additions and 30 deletions.
13 changes: 9 additions & 4 deletions docs/datasets.rst
@@ -49,11 +49,11 @@ A typical Lhotse's dataset API usage might look like this:
for batch in dloader:
... # process data
Pre-computed vs. on-the-fly: input strategies
---------------------------------------------
Batch I/O: pre-computed vs. on-the-fly features
-----------------------------------------------

Depending on the experimental setup and infrastructure, it might be more convenient to either pre-compute and store features like filter-bank energies for later use (as traditionally done in Kaldi/ESPnet/Espresso toolkits), or compute them dynamically during training ("on-the-fly").
Lhotse supports both modes of computation by introducing a class called :class:`~lhotse.dataset.input_strategies.InputStrategy`.
Lhotse supports both modes of computation by introducing a class called :class:`~lhotse.dataset.input_strategies.BatchIO`.
It is accepted as an argument in most dataset classes, and defaults to :class:`~lhotse.dataset.input_strategies.PrecomputedFeatures`.
Other available choices are :class:`~lhotse.dataset.input_strategies.AudioSamples` for working with waveforms directly,
and :class:`~lhotse.dataset.input_strategies.OnTheFlyFeatures`, which wraps a :class:`~lhotse.features.base.FeatureExtractor` and applies it to a batch of recordings. These strategies automatically pad and collate the inputs, and provide information about the original signal lengths: as a number of frames/samples, binary mask, or start-end frame/sample pairs.
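
A minimal usage sketch (not part of the diff shown here) of calling one of these strategies directly on a ``CutSet``; the manifest path is hypothetical:

    from lhotse import CutSet, Fbank
    from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures

    cuts = CutSet.from_json('data/cuts.json')   # hypothetical manifest path
    strategy = OnTheFlyFeatures(Fbank())        # or PrecomputedFeatures() for stored features
    features, feature_lens = strategy(cuts)     # collated (B, T, F) tensor plus per-cut frame counts
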
@@ -121,4 +121,9 @@ These transforms work directly on batches of collated feature matrices (or possi
Collation utilities for building custom Datasets
------------------------------------------------

.. automodule:: lhotse.dataset.collation
.. automodule:: lhotse.dataset.collation

Experimental: LhotseDataLoader
------------------------------

.. autoclass:: lhotse.dataset.dataloading.LhotseDataLoader
3 changes: 2 additions & 1 deletion lhotse/augmentation/torchaudio.py
@@ -12,7 +12,8 @@

if not during_docs_build() and _version(torchaudio.__version__) < _version('0.7'):
warnings.warn('Torchaudio SoX effects chains are only introduced in version 0.7 - '
'please upgrade your PyTorch to 1.7+ and torchaudio to 0.7+ to use them.')
'please upgrade your PyTorch to 1.7.1 and torchaudio to 0.7.2 (or higher) '
'to use them.')


@dataclass
44 changes: 36 additions & 8 deletions lhotse/dataset/collation.py
@@ -1,11 +1,12 @@
from typing import Iterable, List, Tuple, Union
from concurrent.futures import Executor
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

from lhotse import CutSet
from lhotse.cut import MixedCut
from lhotse.cut import Cut, MixedCut


class TokenCollater:
@@ -101,7 +102,8 @@ def inverse(self, tokens_batch: torch.LongTensor, tokens_lens: torch.IntTensor)

def collate_features(
cuts: CutSet,
pad_direction: str = 'right'
pad_direction: str = 'right',
executor: Optional[Executor] = None,
) -> Tuple[torch.Tensor, torch.IntTensor]:
"""
Load features for all the cuts and return them as a batch in a torch tensor.
@@ -110,21 +112,28 @@ def collate_features(
:param cuts: a :class:`CutSet` used to load the features.
:param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
:param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided,
we will use it to read the features concurrently.
:return: a tuple of tensors ``(features, features_lens)``.
"""
assert all(cut.has_features for cut in cuts)
features_lens = torch.tensor([cut.num_frames for cut in cuts], dtype=torch.int)
cuts = maybe_pad(cuts, num_frames=max(features_lens).item(), direction=pad_direction)
first_cut = next(iter(cuts))
features = torch.empty(len(cuts), first_cut.num_frames, first_cut.num_features)
for idx, cut in enumerate(cuts):
features[idx] = torch.from_numpy(cut.load_features())
if executor is None:
for idx, cut in enumerate(cuts):
features[idx] = _read_features(cut)
else:
for idx, example_features in enumerate(executor.map(_read_features, cuts)):
features[idx] = example_features
return features, features_lens
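
A short usage sketch of the new ``executor`` parameter (illustration, not part of the diff); the manifest path is assumed:

    from concurrent.futures import ThreadPoolExecutor

    from lhotse import CutSet
    from lhotse.dataset.collation import collate_features

    cuts = CutSet.from_json('data/cuts.json')   # assumed manifest with precomputed features
    with ThreadPoolExecutor(max_workers=4) as pool:
        features, feature_lens = collate_features(cuts, executor=pool)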


def collate_audio(
cuts: CutSet,
pad_direction: str = 'right'
pad_direction: str = 'right',
executor: Optional[Executor] = None,
) -> Tuple[torch.Tensor, torch.IntTensor]:
"""
Load audio samples for all the cuts and return them as a batch in a torch tensor.
@@ -133,15 +142,21 @@ def collate_audio(
:param cuts: a :class:`CutSet` used to load the audio samples.
:param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
:param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided,
we will use it to read audio concurrently.
:return: a tuple of tensors ``(audio, audio_lens)``.
"""
assert all(cut.has_recording for cut in cuts)
audio_lens = torch.tensor([cut.num_samples for cut in cuts], dtype=torch.int32)
cuts = maybe_pad(cuts, num_samples=max(audio_lens).item(), direction=pad_direction)
first_cut = next(iter(cuts))
audio = torch.empty(len(cuts), first_cut.num_samples)
for idx, cut in enumerate(cuts):
audio[idx] = torch.from_numpy(cut.load_audio()[0])
if executor is None:
for idx, cut in enumerate(cuts):
audio[idx] = _read_audio(cut)
else:
for idx, example_audio in enumerate(executor.map(_read_audio, cuts)):
audio[idx] = example_audio
return audio, audio_lens
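
Similarly, a sketch of reading waveforms concurrently with a process pool (same ``cuts`` assumption as above):

    from concurrent.futures import ProcessPoolExecutor

    from lhotse.dataset.collation import collate_audio

    with ProcessPoolExecutor(max_workers=4) as pool:
        audio, audio_lens = collate_audio(cuts, executor=pool)   # cuts: a CutSet with recordings attached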


@@ -254,3 +269,16 @@ def maybe_pad(
num_samples=num_samples,
direction=direction
)


"""
Helper functions to dispatch jobs to the concurrent executors.
"""


def _read_audio(cut: Cut) -> torch.Tensor:
return torch.from_numpy(cut.load_audio()[0])


def _read_features(cut: Cut) -> torch.Tensor:
return torch.from_numpy(cut.load_features())
105 changes: 105 additions & 0 deletions lhotse/dataset/dataloading.py
@@ -0,0 +1,105 @@
import platform
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context
from typing import Any, Dict, List

import torch.utils.data

from lhotse.dataset.sampling import CutSampler


class LhotseDataLoader:
"""
A simplified ``DataLoader`` implementation that relies on a ``ProcessPoolExecutor``.
The main difference between this and ``torch.utils.data.DataLoader`` is that
:class:`.LhotseDataLoader` allows launching subprocesses inside of its workers.
This is useful for working with dataset classes which perform dynamic batching
and need to perform concurrent I/O to read all the necessary data from disk/network.
.. note:: :class:`.LhotseDataLoader` does not support ``num_workers=0``.
.. warning:: :class:`.LhotseDataLoader` is experimental and not guaranteed to work
correctly across all possible edge cases related to subprocess worker termination.
If you experience stability problems, contact us or use a standard ``DataLoader``
instead.
.. warning:: :class:`.LhotseDataLoader` requires Python >= 3.7.
"""

def __init__(
self,
dataset: torch.utils.data.Dataset,
sampler: CutSampler,
num_workers: int = 1,
prefetch_factor: int = 2,
) -> None:
from packaging.version import parse as _version

if _version(platform.python_version()) < _version("3.7"):
raise RuntimeError("LhotseDataLoader requires Python version at least 3.7")
assert num_workers >= 1
assert prefetch_factor >= 1
self.dataset = dataset
self.sampler = sampler
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
# Mutable state
self._iter = None
self._futures = deque([])
# Start the worker processes. The initializer receives the dataset object
# from the main process and caches it globally, so that it can be re-used
# for subsequent tasks sent to the worker. This helps avoid excessive
# communication between the processes.
self.pool = ProcessPoolExecutor(
num_workers,
initializer=_init_worker,
initargs=(dataset,),
mp_context=get_context("spawn"),
)

def __iter__(self) -> "LhotseDataLoader":
"""Prepares the sampler for iteration and schedules initial tasks to the workers."""
self._iter = iter(self.sampler)
for _ in range(self.prefetch_factor * self.num_workers):
self._schedule_one()
return self

def _schedule_one(self) -> None:
"""Submits a task and stores the future for results retrieval."""
if self._iter is not None:
try:
self._futures.append(self.pool.submit(_get_item, next(self._iter)))
except StopIteration:
self._iter = None

def _retrieve_one(self) -> Dict[str, Any]:
"""Retrieves the result from the earliest submitted task."""
if self._futures:
return self._futures.popleft().result()
raise StopIteration()

def __next__(self) -> Dict[str, Any]:
"""Submits a new batch to process and then retrieves and returns a completed batch."""
self._schedule_one()
return self._retrieve_one()


def _init_worker(dataset: torch.utils.data.Dataset) -> None:
"""
Stores the dataset in the global state of the process -- this is safe because
the process is initialized only once and used for a single dataset in its life span.
"""
global _GLOBAL_DATASET_CACHE
_GLOBAL_DATASET_CACHE = dataset


def _get_item(cut_ids: List[str]) -> Dict[str, Any]:
"""
Queries the globally cached dataset to retrieve a batch. Has to be run
inside a worker process that was initialized with :meth:`._init_worker`.
"""
return _GLOBAL_DATASET_CACHE[cut_ids]


_GLOBAL_DATASET_CACHE = None
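
A usage sketch for the class above (illustration only); the ``dataset`` object is assumed to be a Lhotse-style map dataset indexed by lists of cut IDs, and the sampler settings are made up:

    from lhotse.dataset.dataloading import LhotseDataLoader
    from lhotse.dataset.sampling import SingleCutSampler

    sampler = SingleCutSampler(cuts, max_duration=300)   # cuts: a CutSet (assumed)
    loader = LhotseDataLoader(dataset, sampler, num_workers=2, prefetch_factor=2)
    for batch in loader:
        ...   # run a training step on the batch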
49 changes: 38 additions & 11 deletions lhotse/dataset/input_strategies.py
@@ -1,5 +1,7 @@
import logging
from typing import Callable, Dict, List, Tuple, Optional
from concurrent.futures import ProcessPoolExecutor
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple

import torch

@@ -9,16 +9,24 @@
from lhotse.utils import compute_num_frames, ifnone, supervision_to_frames, supervision_to_samples


class InputStrategy:
class BatchIO:
"""
Converts a :class:`CutSet` into a collated batch of audio representations.
These representations can be e.g. audio samples or features.
They might also be single or multi channel.
This is a base class that only defines the interface.
All InputStrategies support the ``executor`` parameter in the constructor.
It allows passing a ``ThreadPoolExecutor`` or a ``ProcessPoolExecutor``
to parallelize reading audio/features from wherever they are stored.
Note that this approach is incompatible with specifying the ``num_workers``
to ``torch.utils.data.DataLoader``, but in some instances may be faster.
.. note:: This is a base class that only defines the interface.
.. automethod:: __call__
"""
def __init__(self, num_workers: int = 0) -> None:
self.num_workers = num_workers

def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.IntTensor]:
"""Returns a tensor with collated input signals, and a tensor of length of each signal before padding."""
@@ -68,7 +78,7 @@ def supervision_masks(self, cuts: CutSet) -> torch.Tensor:
raise NotImplementedError()


class PrecomputedFeatures(InputStrategy):
class PrecomputedFeatures(BatchIO):
"""
:class:`InputStrategy` that reads pre-computed features, whose manifests
are attached to cuts, from disk.
@@ -84,7 +94,7 @@ def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.IntTensor]:
The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.
:return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding."""
return collate_features(cuts)
return collate_features(cuts, executor=_get_executor(self.num_workers))

def supervision_intervals(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
"""
@@ -123,7 +133,7 @@ def supervision_masks(self, cuts: CutSet, use_alignment_if_exists: Optional[str]
return collate_vectors([cut.supervisions_feature_mask(use_alignment_if_exists=use_alignment_if_exists) for cut in cuts])


class AudioSamples(InputStrategy):
class AudioSamples(BatchIO):
"""
:class:`InputStrategy` that reads single-channel recordings, whose manifests
are attached to cuts, from disk (or other audio source).
@@ -132,15 +142,14 @@ class AudioSamples(InputStrategy):
.. automethod:: __call__
"""

def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.IntTensor]:
"""
Reads the audio samples from recordings on disk/other storage.
The returned shape is ``(B, T) => (batch_size, num_samples)``.
:return: a tensor with collated audio samples, and a tensor of ``num_samples`` of each cut before padding.
"""
return collate_audio(cuts)
return collate_audio(cuts, executor=_get_executor(self.num_workers))

def supervision_intervals(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
"""
@@ -180,7 +189,7 @@ def supervision_masks(self, cuts: CutSet, use_alignment_if_exists: Optional[str]
return collate_vectors([cut.supervisions_audio_mask(use_alignment_if_exists=use_alignment_if_exists) for cut in cuts])


class OnTheFlyFeatures(InputStrategy):
class OnTheFlyFeatures(BatchIO):
"""
:class:`InputStrategy` that reads single-channel recordings, whose manifests
are attached to cuts, from disk (or other audio source).
@@ -199,7 +208,8 @@ class OnTheFlyFeatures(InputStrategy):
def __init__(
self,
extractor: FeatureExtractor,
wave_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None
wave_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
num_workers: int = 0,
) -> None:
"""
OnTheFlyFeatures' constructor.
@@ -208,6 +218,7 @@ def __init__(
:param wave_transforms: an optional list of transforms applied on the batch of audio
waveforms collated into a single tensor, right before the feature extraction.
"""
super().__init__(num_workers=num_workers)
self.extractor = extractor
self.wave_transforms = ifnone(wave_transforms, [])

@@ -219,7 +230,7 @@ def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.IntTensor]:
:return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
"""
audio, _ = collate_audio(cuts)
audio, _ = collate_audio(cuts, executor=_get_executor(self.num_workers))

for tfnm in self.wave_transforms:
audio = tfnm(audio)
@@ -288,3 +299,19 @@ def supervision_masks(self, cuts: CutSet, use_alignment_if_exists: Optional[str]
) for cut in cuts
]
)


@lru_cache(maxsize=1)
def _get_executor(max_workers: int = 0) -> Optional[ProcessPoolExecutor]:
"""
This function caches a process pool in the global state of a given process.
It's useful for keeping a process pool alive across different invocations within the
same process for efficiency.
We intend it to be used for efficient data reads within a task executed in a
parent process pool.
"""
if max_workers <= 0:
return None
return ProcessPoolExecutor(max_workers=max_workers)
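
A small illustration (not from the diff) of the caching behavior of this internal helper:

    pool_a = _get_executor(4)
    pool_b = _get_executor(4)
    assert pool_a is pool_b           # lru_cache hands back the same pool for repeated calls
    assert _get_executor(0) is None   # non-positive worker counts disable the pool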


4 changes: 2 additions & 2 deletions lhotse/dataset/speech_recognition.py
@@ -5,7 +5,7 @@

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import InputStrategy, PrecomputedFeatures
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone


@@ -63,7 +63,7 @@ def __init__(
return_cuts: bool = False,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: InputStrategy = PrecomputedFeatures(),
input_strategy: BatchIO = PrecomputedFeatures(),
check_inputs: bool = True
):
"""