
Improved random seed configuration for Lhotse dataloaders with docs (NVIDIA#9001)

* Improving RNG seeding with Lhotse dataloading

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Fix

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Add documentation about random seeds

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Add doc about managing random seed

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Apply suggestions from code review

Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com>
Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com>
pzelasko and erastorgueva-nv authored Apr 26, 2024
1 parent bf2f6f0 commit d6e8ff1
Showing 5 changed files with 235 additions and 23 deletions.
48 changes: 48 additions & 0 deletions docs/source/asr/datasets.rst
@@ -823,6 +823,54 @@ For multi-dataset setups, one may provide multiple manifests and even their weights
bucket_duration_bins=[1.91,3.02,3.56,...
<other diagnostic information about the dataset>
Seeds and randomness
~~~~~~~~~~~~~~~~~~~~
In the Lhotse dataloading configuration, two parameters control randomness: ``seed`` and ``shard_seed``.
Both can be set either to a fixed integer or to one of two string options: ``"randomized"`` or ``"trng"``.
Their roles are:
* ``seed`` is the base random seed, and is one of several factors used to initialize the various RNGs participating in dataloading.
* ``shard_seed`` controls the shard randomization strategy in distributed data parallel setups when using sharded tarred datasets (see the configuration example below).
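For example, the two options sit alongside the other dataloading parameters (a minimal sketch in the style of this commit's tests; the manifest path and the remaining keys are illustrative)::

    from omegaconf import OmegaConf

    # Minimal sketch of a Lhotse dataloading config, focused on RNG control.
    config = OmegaConf.create(
        {
            "manifest_filepath": "manifest.jsonl",  # illustrative path
            "sample_rate": 16000,
            "shuffle": True,
            "num_workers": 1,
            "seed": 0,             # base random seed
            "shard_seed": "trng",  # shard randomization strategy (see Case 1)
        }
    )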
Below are the typical configurations with an explanation of the expected outcomes.
Case 1 (default): ``seed=<int>`` and ``shard_seed="trng"``:
* The ``trng`` setting discards ``seed`` and causes the actual random seed to be drawn from the OS's true RNG. Each node/GPU/dataloading worker draws its own unique random seed when it first needs it.
* Each node/GPU/dataloading worker yields data in a different order (no mini-batch duplication).
* On each run of the training script, the order of dataloader examples is **different**.
* Since the random seed is unpredictable, the exact dataloading order is not replicable.
Case 2: ``seed=<int>`` and ``shard_seed="randomized"``:
* The ``randomized`` setting uses ``seed`` along with DDP ``rank`` and dataloading ``worker_id`` to set a unique but deterministic random seed in each dataloading process across all GPUs.
* Each node/GPU/dataloading worker yields data in a different order (no mini-batch duplication).
* On each run of the training script, the order of dataloader examples is **identical** as long as ``seed`` stays the same.
* This setup guarantees 100% dataloading reproducibility.
* Resuming training without changing the ``seed`` value will cause the model to train on data it has already seen. For large data setups, not managing the ``seed`` may cause the model to never be trained on a majority of the data. This is why this mode is not the default.
* If you're combining DDP with model parallelism techniques (Tensor Parallel, Pipeline Parallel, etc.), you need to use ``shard_seed="randomized"``. Using ``"trng"`` will desynchronize the model parallel ranks and cause a deadlock.
* Generally, the seed can be managed by providing a different value each time the training script is launched. For most models, the option to override is ``model.train_ds.seed=<value>``. If you're launching multiple tasks queued one after another on a grid system, generate a different random seed for each task; on most Unix systems, ``RSEED=$(od -An -N4 -tu4 < /dev/urandom | tr -d ' ')`` produces a random uint32 that can be passed as the seed (a Python equivalent is shown below).
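The same kind of seed can be drawn in Python (a small sketch; any source of 32-bit randomness works)::

    import os

    # Draw a random uint32 from the OS entropy pool, e.g. to pass as
    # model.train_ds.seed=<value> for each queued training run.
    rseed = int.from_bytes(os.urandom(4), byteorder="little")
    print(rseed)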
Other, more exotic configurations:
* With ``shard_seed=<int>``, all dataloading workers yield the same results. This is useful only for unit testing and debugging.
* With ``seed="trng"``, the base random seed itself will be drawn using a TRNG. It will be different on each GPU training process. This setting is not recommended.
* With ``seed="randomized"``, the base random seed is set to Python's global RNG seed. It might be different on each GPU training process. This setting is not recommended.
Preparing Text-Only Data for Hybrid ASR-TTS Models
--------------------------------------------------
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
@@ -40,7 +40,7 @@
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common import tokenizers
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config
from nemo.collections.common.metrics import GlobalAverageLossMetric
from nemo.collections.common.parts import transformer_weights_init
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
31 changes: 18 additions & 13 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -29,14 +29,13 @@
IterableDatasetWrapper,
make_worker_init_fn,
)
from lhotse.dataset.dataloading import resolve_seed
from lhotse.dataset.sampling.base import SamplingConstraint, TimeConstraint, TokenConstraint
from lhotse.lazy import LazyFlattener
from lhotse.utils import fastcopy
from lhotse.utils import fastcopy, fix_random_seed
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.utils import logging


@@ -87,7 +86,7 @@ class LhotseDataLoadingConfig:
sample_rate: int = 16000
min_duration: float | None = -1
max_duration: float | None = float("inf")
seed: int | str = "randomized" # int | "randomized" | "trng"; the latter two are lazily resolved by Lhotse in dloading worker processes
seed: int | str = 0
num_workers: int = 0
pin_memory: bool = False

@@ -123,11 +122,7 @@ class LhotseDataLoadingConfig:


def get_lhotse_dataloader_from_config(
config: DictConfig,
global_rank: int,
world_size: int,
dataset: torch.utils.data.Dataset,
tokenizer: TokenizerSpec | TokenizerWrapper = None,
config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None,
) -> torch.utils.data.DataLoader:
"""
Set up a Lhotse training dataloader.
@@ -154,6 +149,10 @@ def get_lhotse_dataloader_from_config(

config = make_structured_with_schema_warnings(config)

# First, resolve the random seed in case a string value was provided.
seed = resolve_seed(config.seed)
fix_random_seed(seed)

# 1. Load a manifest as a Lhotse CutSet.
cuts, is_tarred = read_cutset_from_config(config)

@@ -167,6 +166,8 @@
assert (
tokenizer is not None
), "You must pass a tokenizer to `get_lhotse_dataloader_from_config` in order to read text-only datasets (enabled via use_multimodal_dataloading)"
from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper

if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
# Note this code can also pre-tokenize the text in cuts, but for now we disable it with apply_fn.
@@ -177,7 +178,11 @@
if config.noise_path is not None:
noise = CutSet.from_file(config.noise_path)
cuts = cuts.mix(
cuts=noise, snr=config.noise_snr, mix_prob=config.noise_mix_prob, seed="trng", random_mix_offset=True
cuts=noise,
snr=config.noise_snr,
mix_prob=config.noise_mix_prob,
seed=config.shard_seed,
random_mix_offset=True,
)

# 2.b. On-the-fly speed perturbation.
@@ -235,7 +240,7 @@ def get_lhotse_dataloader_from_config(
shuffle=config.shuffle,
drop_last=config.drop_last,
shuffle_buffer_size=config.shuffle_buffer_size,
seed=config.seed,
seed=config.shard_seed,
num_buckets=config.num_buckets,
duration_bins=config.bucket_duration_bins,
num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate,
@@ -257,7 +262,7 @@ def get_lhotse_dataloader_from_config(
shuffle=config.shuffle,
drop_last=config.drop_last,
shuffle_buffer_size=config.shuffle_buffer_size,
seed=config.seed,
seed=config.shard_seed,
rank=0 if is_tarred else global_rank,
world_size=1 if is_tarred else world_size,
)
@@ -289,7 +294,7 @@
# This together with infinite datasets removes the need to split data across nodes/workers.
dloader_kwargs = dict(
dataset=IterableDatasetWrapper(dataset=dataset, sampler=sampler),
worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size),
worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=seed),
persistent_workers=config.num_workers > 0, # helps Lhotse Shar maintain shuffling state
)
else:
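Taken together, the seeding flow introduced in this file can be summarized as follows (a condensed paraphrase of the diff above, not the verbatim source)::

    # Inside get_lhotse_dataloader_from_config (condensed):
    seed = resolve_seed(config.seed)  # int | "randomized" | "trng" -> concrete int
    fix_random_seed(seed)             # seed the RNGs of the current process

    # The sampler constructions and the noise-mixing call now receive
    # config.shard_seed (rather than config.seed) as their ``seed``,
    # so shard_seed governs shard and shuffling randomization:
    #   sampler = ...(..., seed=config.shard_seed, ...)
    #   cuts.mix(..., seed=config.shard_seed, ...)

    # Dataloading worker processes derive their seeds from the resolved base seed:
    #   worker_init_fn = make_worker_init_fn(
    #       rank=global_rank, world_size=world_size, seed=seed
    #   )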
15 changes: 6 additions & 9 deletions tests/collections/common/test_lhotse_dataloading.py
@@ -339,6 +339,11 @@ def test_dataloader_from_nemo_manifest(nemo_manifest_path: Path):
assert b["audio"].shape[0] == b["audio_lens"].shape[0] == 1


class _Identity:
def __getitem__(self, cuts):
return cuts


def test_dataloader_from_nemo_manifest_has_custom_fields(nemo_manifest_path: Path):
config = OmegaConf.create(
{
@@ -356,11 +361,7 @@ def test_dataloader_from_nemo_manifest_has_custom_fields(nemo_manifest_path: Path):
}
)

class _IdentityDataset:
def __getitem__(self, cuts):
return cuts

dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_IdentityDataset())
dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_Identity())

batch = next(iter(dl))
for cut in batch:
@@ -852,10 +853,6 @@ def test_lhotse_cuts_resolve_relative_paths(tmp_path: Path):
{"cuts_path": cuts_path, "sample_rate": 16000, "use_lhotse": True, "num_workers": 0, "batch_size": 2,}
)

class _Identity(torch.utils.data.Dataset):
def __getitem__(self, x):
return x

dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_Identity())

batches = [batch for batch in dl]
162 changes: 162 additions & 0 deletions tests/collections/common/test_lhotse_multirank_rng.py
@@ -0,0 +1,162 @@
from io import BytesIO
from pathlib import Path

import pytest
from lhotse import CutSet
from lhotse.serialization import load_jsonl, save_to_jsonl
from lhotse.shar.writers import JsonlShardWriter, TarWriter
from lhotse.testing.dummies import DummyManifest
from omegaconf import OmegaConf

from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config


class _Identity:
def __getitem__(self, cuts):
return cuts


@pytest.fixture(scope="session")
def cutset_path(tmp_path_factory) -> Path:
"""10 utterances of length 1s as a Lhotse CutSet."""
cuts = DummyManifest(CutSet, begin_id=0, end_id=10, with_data=True)
for c in cuts:
c.features = None
c.custom = None
c.supervisions[0].custom = None

tmp_path = tmp_path_factory.mktemp("data")
p = tmp_path / "cuts.jsonl.gz"
pa = tmp_path / "audio"
cuts.save_audios(pa).to_file(p)
return p


@pytest.fixture(scope="session")
def nemo_manifest_path(cutset_path: Path):
"""10 utterances of length 1s as a NeMo manifest."""
nemo = []
for idx, c in enumerate(CutSet.from_file(cutset_path)):
nemo.append(
{"audio_filepath": c.recording.sources[0].source, "text": f"irrelevant-{idx}", "duration": c.duration,}
)
p = cutset_path.parent / "nemo_manifest.json"
save_to_jsonl(nemo, p)
return p


@pytest.fixture(scope="session")
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> tuple[str, str]:
"""5 shards, each with 2 utterances."""
root = nemo_manifest_path.parent / "nemo_tar"
root.mkdir(exist_ok=True)
with TarWriter(f"{root}/audios_%01d.tar", shard_size=2) as tar_writer, JsonlShardWriter(
f"{root}/manifest_%01d.jsonl", shard_size=2
) as mft_writer:
for idx, d in enumerate(load_jsonl(nemo_manifest_path)):
p = d["audio_filepath"]
name = Path(p).name
with open(p, "rb") as f:
tar_writer.write(name, BytesIO(f.read()))
mft_writer.write({**d, "audio_filepath": name, "shard_id": idx // 2})
return f"{root}/manifest__OP_0..4_CL_.jsonl", f"{root}/audios__OP_0..4_CL_.tar"


def test_dataloader_multiple_ranks_deterministic_rng(nemo_tarred_manifest_path: tuple[str, str]):
json_mft, tar_mft = nemo_tarred_manifest_path
config = OmegaConf.create(
{
"manifest_filepath": json_mft,
"tarred_audio_filepaths": tar_mft,
"sample_rate": 16000,
"shuffle": True,
"use_lhotse": True,
"num_workers": 1,
# lhotse specific
"use_bucketing": True,
"num_buckets": 2,
"drop_last": False,
"batch_duration": 4.0, # seconds
"quadratic_duration": 15.0, # seconds
"shuffle_buffer_size": 10,
"bucket_buffer_size": 100,
"seed": 0,
"shard_seed": "randomized",
}
)

# Data parallel, rank 0
dp0 = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity())

# Data parallel, rank 0 copy (is the iteration deterministic? -> yes)
dp0_cpy = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 0, incremented seed (paranoia mode: does the iteration order change with the seed? -> yes)
config2 = config.copy()
config2["seed"] = config2["seed"] + 1
dp0_incrseed = get_lhotse_dataloader_from_config(config=config2, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 1 (is data different on each DP rank? -> yes)
dp1 = get_lhotse_dataloader_from_config(config=config, global_rank=1, world_size=2, dataset=_Identity())

dloaders = zip(*[iter(dl) for dl in (dp0, dp0_cpy, dp0_incrseed, dp1)])

for i in range(5):
b0, b0_cpy, b0_incrseed, b1 = next(dloaders)
assert b0 == b0_cpy
assert b0 != b1
assert b0_incrseed != b1
assert b0 != b0_incrseed


def test_dataloader_multiple_ranks_trng(nemo_tarred_manifest_path: tuple[str, str]):
"""
This test is the same as ``test_dataloader_multiple_ranks_deterministic_rng``,
except that we set ``shard_seed="trng"`` which causes the seed to be lazily
resolved in subprocesses (resolved => being drawn using OS's TRNG).
Therefore, we don't expect any reproducibility.
"""
json_mft, tar_mft = nemo_tarred_manifest_path
config = OmegaConf.create(
{
"manifest_filepath": json_mft,
"tarred_audio_filepaths": tar_mft,
"sample_rate": 16000,
"shuffle": True,
"use_lhotse": True,
"num_workers": 1,
# lhotse specific
"use_bucketing": True,
"num_buckets": 2,
"drop_last": False,
"batch_duration": 4.0, # seconds
"quadratic_duration": 15.0, # seconds
"shuffle_buffer_size": 10,
"bucket_buffer_size": 100,
"seed": 0,
"shard_seed": "trng",
}
)

# Data parallel, rank 0
dp0 = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity())

# Data parallel, rank 0 copy (is the iteration deterministic? -> no, trng)
dp0_cpy = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 0, incremented seed (paranoia mode: does the iteration order change with the seed? -> yes)
config2 = config.copy()
config2["seed"] = config2["seed"] + 1
dp0_incrseed = get_lhotse_dataloader_from_config(config=config2, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 1 (is data different on each DP rank? -> yes)
dp1 = get_lhotse_dataloader_from_config(config=config, global_rank=1, world_size=2, dataset=_Identity())

dloaders = zip(*[iter(dl) for dl in (dp0, dp0_cpy, dp0_incrseed, dp1)])

for i in range(5):
b0, b0_cpy, b0_incrseed, b1 = next(dloaders)
assert b0 != b0_cpy
assert b0 != b1
assert b0_incrseed != b1
assert b0 != b0_incrseed
