Improved random seed configuration for Lhotse dataloaders with docs #9001

Merged 7 commits on Apr 26, 2024
46 changes: 46 additions & 0 deletions docs/source/asr/datasets.rst
@@ -823,6 +823,52 @@ For multi-dataset setups, one may provide multiple manifests and even their weights
bucket_duration_bins=[1.91,3.02,3.56,...
<other diagnostic information about the dataset>


Seeds and randomness
~~~~~~~~~~~~~~~~~~~~

In the Lhotse dataloading configuration, two parameters control randomness: ``seed`` and ``shard_seed``.
Each of them can be set either to a fixed integer or to one of two string options: ``"randomized"`` or ``"trng"``.
Their roles are:

* ``seed`` is the base random seed, and is one of several factors used to initialize various RNGs participating in dataloading.

* ``shard_seed`` controls the shard randomization strategy in distributed data parallel setups when using sharded tarred datasets.

Below are typical configuration examples with an explanation of the expected outcome; each case is followed by a short configuration sketch.

Case 1 (default): ``seed=<int>`` and ``shard_seed="trng"``:

* The ``trng`` setting discards ``seed`` and causes the actual random seed to be drawn using the OS's true RNG. Each node/GPU/dataloading worker draws its own unique random seed when it first needs it.

* Each node/GPU/dataloading worker yields data in a different order (no mini-batch duplication).

* On each training script run, the order of dataloader examples is **different**.

* Since the random seed is unpredictable, the exact dataloading order is not replicable.
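
For concreteness, here is a minimal Python sketch of this default setup, modeled on the tests added in this PR. The tarred manifest paths and the ``_Identity`` dataset are placeholders for illustration.

.. code-block:: python

    from omegaconf import OmegaConf

    from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config


    class _Identity:
        """Returns each mini-batch (a CutSet) unchanged, as in the tests in this PR."""

        def __getitem__(self, cuts):
            return cuts


    # Case 1 (default): fixed base seed, shard seed drawn lazily from the OS TRNG.
    config = OmegaConf.create(
        {
            "manifest_filepath": "nemo_tar/manifest__OP_0..4_CL_.jsonl",   # placeholder
            "tarred_audio_filepaths": "nemo_tar/audios__OP_0..4_CL_.tar",  # placeholder
            "sample_rate": 16000,
            "shuffle": True,
            "use_lhotse": True,
            "num_workers": 1,
            "batch_size": 4,
            "seed": 0,             # base seed: an int
            "shard_seed": "trng",  # each node/GPU/worker draws its own seed
        }
    )
    dl = get_lhotse_dataloader_from_config(
        config=config, global_rank=0, world_size=1, dataset=_Identity()
    )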

Case 2: ``seed=<int>`` and ``shard_seed="randomized"``:

* The ``randomized`` setting uses ``seed`` along with DDP ``rank`` and dataloading ``worker_id`` to set a unique but deterministic random seed in each dataloading process across all GPUs.

* Each node/GPU/dataloading worker yields data in a different order (no mini-batch duplication).

* On each training script run, the order of dataloader examples is **identical** as long as ``seed`` is the same.

* This setup guarantees 100% dataloading reproducibility.

* Resuming training without changing the ``seed`` value will cause the model to train on data it has already seen. For large data setups, not managing the ``seed`` may cause the model to never see the majority of the data. This is why this mode is not the default.
Collaborator comment: What's the recommended practice for managing the "seed" on a large data setup?

@pzelasko (Collaborator, Author) replied on Apr 23, 2024:
Generally, every time you resume, you'd provide a different value to ``model.train_ds.seed=<val>``. A true-enough random seed can be obtained on most systems by reading ``/dev/urandom``, e.g. a uint32 seed: ``RSEED=$(od -An -N4 -tu4 < /dev/urandom | tr -d ' ')``. If you have some sort of "launcher script" that queues multiple jobs, that would be the right place to use it. Let me update the docs with this example.

Ideally we'd be able to automate this seed management by keeping some state in the checkpoints, but at this point it'd be scope creep.
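
A Python equivalent of that one-liner, as a sketch for launchers written in Python rather than shell:

```python
import os

# Draw a fresh uint32 seed from the OS TRNG (equivalent to reading /dev/urandom).
rseed = int.from_bytes(os.urandom(4), "little")
# Pass it to the resumed run as the override mentioned above, e.g.:
print(f"model.train_ds.seed={rseed}")
```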


* If you're combining DDP with model parallelism techniques (Tensor Parallel, Pipeline Parallel, etc.), you need to use ``shard_seed="randomized"``. Using ``"trng"`` will cause different model parallel ranks to desynchronize, which results in a deadlock.
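
The same sketch adapted to Case 2 changes only the two seed-related keys; the dataloader construction is identical to the Case 1 sketch above.

.. code-block:: python

    from omegaconf import OmegaConf

    # Case 2: fully reproducible dataloading for a given ``seed`` value.
    config = OmegaConf.create(
        {
            "manifest_filepath": "nemo_tar/manifest__OP_0..4_CL_.jsonl",   # placeholder
            "tarred_audio_filepaths": "nemo_tar/audios__OP_0..4_CL_.tar",  # placeholder
            "sample_rate": 16000,
            "shuffle": True,
            "use_lhotse": True,
            "num_workers": 1,
            "batch_size": 4,
            "seed": 12345,               # change this value on every resume
            "shard_seed": "randomized",  # deterministic per (seed, rank, worker_id)
        }
    )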

Other, more exotic configurations (a short sketch of the unit-testing setup follows this list):

* With ``shard_seed=<int>``, all dataloading workers will yield the same results. This is only useful for unit testing and maybe debugging.

* With ``seed="trng"``, the base random seed itself will be drawn using a TRNG. It will be different on each GPU training process. This setting is not recommended.

* With ``seed="randomized"``, the base random seed is set to Python's global RNG seed. It might be different on each GPU training process. This setting is not recommended.
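
As a short sketch (reusing the ``config`` object from the examples above), the unit-testing setup from the first bullet looks like this:

.. code-block:: python

    # Debug/unit-test only: with a fixed integer shard_seed, every dataloading
    # worker and DDP rank yields mini-batches in the same order (duplication).
    config.seed = 0
    config.shard_seed = 0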

Preparing Text-Only Data for Hybrid ASR-TTS Models
--------------------------------------------------

2 changes: 1 addition & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
@@ -40,7 +40,7 @@
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common import tokenizers
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config
from nemo.collections.common.metrics import GlobalAverageLossMetric
from nemo.collections.common.parts import transformer_weights_init
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
31 changes: 18 additions & 13 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -29,14 +29,13 @@
IterableDatasetWrapper,
make_worker_init_fn,
)
from lhotse.dataset.dataloading import resolve_seed
from lhotse.dataset.sampling.base import SamplingConstraint, TimeConstraint, TokenConstraint
from lhotse.lazy import LazyFlattener
from lhotse.utils import fastcopy
from lhotse.utils import fastcopy, fix_random_seed
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.utils import logging


@@ -87,7 +86,7 @@ class LhotseDataLoadingConfig:
sample_rate: int = 16000
min_duration: float | None = -1
max_duration: float | None = float("inf")
seed: int | str = "randomized" # int | "randomized" | "trng"; the latter two are lazily resolved by Lhotse in dloading worker processes
seed: int | str = 0
num_workers: int = 0
pin_memory: bool = False

@@ -123,11 +122,7 @@ class LhotseDataLoadingConfig:


def get_lhotse_dataloader_from_config(
config: DictConfig,
global_rank: int,
world_size: int,
dataset: torch.utils.data.Dataset,
tokenizer: TokenizerSpec | TokenizerWrapper = None,
config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None,
) -> torch.utils.data.DataLoader:
"""
Set up a Lhotse training dataloader.
@@ -154,6 +149,10 @@ def get_lhotse_dataloader_from_config(

config = make_structured_with_schema_warnings(config)

# First, resolve the random seed in case a string value was provided.
seed = resolve_seed(config.seed)
fix_random_seed(seed)

# 1. Load a manifest as a Lhotse CutSet.
cuts, is_tarred = read_cutset_from_config(config)

@@ -167,6 +166,8 @@ def get_lhotse_dataloader_from_config(
assert (
tokenizer is not None
), "You must pass a tokenizer to `get_lhotse_dataloader_from_config` in order to read text-only datasets (enabled via use_multimodal_dataloading)"
from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper

if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
# Note this code can also pre-tokenize the text in cuts, but for now we disable it with apply_fn.
@@ -177,7 +178,11 @@ def get_lhotse_dataloader_from_config(
if config.noise_path is not None:
noise = CutSet.from_file(config.noise_path)
cuts = cuts.mix(
cuts=noise, snr=config.noise_snr, mix_prob=config.noise_mix_prob, seed="trng", random_mix_offset=True
cuts=noise,
snr=config.noise_snr,
mix_prob=config.noise_mix_prob,
seed=config.shard_seed,
random_mix_offset=True,
)

# 2.b. On-the-fly speed perturbation.
@@ -235,7 +240,7 @@ def get_lhotse_dataloader_from_config(
shuffle=config.shuffle,
drop_last=config.drop_last,
shuffle_buffer_size=config.shuffle_buffer_size,
seed=config.seed,
seed=config.shard_seed,
num_buckets=config.num_buckets,
duration_bins=config.bucket_duration_bins,
num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate,
@@ -257,7 +262,7 @@ get_lhotse_dataloader_from_config(
shuffle=config.shuffle,
drop_last=config.drop_last,
shuffle_buffer_size=config.shuffle_buffer_size,
seed=config.seed,
seed=config.shard_seed,
rank=0 if is_tarred else global_rank,
world_size=1 if is_tarred else world_size,
)
@@ -289,7 +294,7 @@ def get_lhotse_dataloader_from_config(
# This together with infinite datasets removes the need to split data across nodes/workers.
dloader_kwargs = dict(
dataset=IterableDatasetWrapper(dataset=dataset, sampler=sampler),
worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size),
worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=seed),
persistent_workers=config.num_workers > 0, # helps Lhotse Shar maintain shuffling state
)
else:
15 changes: 6 additions & 9 deletions tests/collections/common/test_lhotse_dataloading.py
@@ -339,6 +339,11 @@ def test_dataloader_from_nemo_manifest(nemo_manifest_path: Path):
assert b["audio"].shape[0] == b["audio_lens"].shape[0] == 1


class _Identity:
def __getitem__(self, cuts):
return cuts


def test_dataloader_from_nemo_manifest_has_custom_fields(nemo_manifest_path: Path):
config = OmegaConf.create(
{
@@ -356,11 +361,7 @@ def test_dataloader_from_nemo_manifest_has_custom_fields(nemo_manifest_path: Path):
}
)

class _IdentityDataset:
def __getitem__(self, cuts):
return cuts

dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_IdentityDataset())
dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_Identity())

batch = next(iter(dl))
for cut in batch:
@@ -852,10 +853,6 @@ def test_lhotse_cuts_resolve_relative_paths(tmp_path: Path):
{"cuts_path": cuts_path, "sample_rate": 16000, "use_lhotse": True, "num_workers": 0, "batch_size": 2,}
)

class _Identity(torch.utils.data.Dataset):
def __getitem__(self, x):
return x

dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_Identity())

batches = [batch for batch in dl]
162 changes: 162 additions & 0 deletions tests/collections/common/test_lhotse_multirank_rng.py
@@ -0,0 +1,162 @@
from io import BytesIO
from pathlib import Path

import pytest
from lhotse import CutSet
from lhotse.serialization import load_jsonl, save_to_jsonl
from lhotse.shar.writers import JsonlShardWriter, TarWriter
from lhotse.testing.dummies import DummyManifest
from omegaconf import OmegaConf

from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config


class _Identity:
def __getitem__(self, cuts):
return cuts


@pytest.fixture(scope="session")
def cutset_path(tmp_path_factory) -> Path:
"""10 utterances of length 1s as a Lhotse CutSet."""
cuts = DummyManifest(CutSet, begin_id=0, end_id=10, with_data=True)
for c in cuts:
c.features = None
c.custom = None
c.supervisions[0].custom = None

tmp_path = tmp_path_factory.mktemp("data")
p = tmp_path / "cuts.jsonl.gz"
pa = tmp_path / "audio"
cuts.save_audios(pa).to_file(p)
return p


@pytest.fixture(scope="session")
def nemo_manifest_path(cutset_path: Path):
"""10 utterances of length 1s as a NeMo manifest."""
nemo = []
for idx, c in enumerate(CutSet.from_file(cutset_path)):
nemo.append(
{"audio_filepath": c.recording.sources[0].source, "text": f"irrelevant-{idx}", "duration": c.duration,}
)
p = cutset_path.parent / "nemo_manifest.json"
save_to_jsonl(nemo, p)
return p


@pytest.fixture(scope="session")
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> tuple[str, str]:
"""5 shards, each with 2 utterances."""
root = nemo_manifest_path.parent / "nemo_tar"
root.mkdir(exist_ok=True)
with TarWriter(f"{root}/audios_%01d.tar", shard_size=2) as tar_writer, JsonlShardWriter(
f"{root}/manifest_%01d.jsonl", shard_size=2
) as mft_writer:
for idx, d in enumerate(load_jsonl(nemo_manifest_path)):
p = d["audio_filepath"]
name = Path(p).name
with open(p, "rb") as f:
tar_writer.write(name, BytesIO(f.read()))
mft_writer.write({**d, "audio_filepath": name, "shard_id": idx // 2})
return f"{root}/manifest__OP_0..4_CL_.jsonl", f"{root}/audios__OP_0..4_CL_.tar"


def test_dataloader_multiple_ranks_deterministic_rng(nemo_tarred_manifest_path: tuple[str, str]):
json_mft, tar_mft = nemo_tarred_manifest_path
config = OmegaConf.create(
{
"manifest_filepath": json_mft,
"tarred_audio_filepaths": tar_mft,
"sample_rate": 16000,
"shuffle": True,
"use_lhotse": True,
"num_workers": 1,
# lhotse specific
"use_bucketing": True,
"num_buckets": 2,
"drop_last": False,
"batch_duration": 4.0, # seconds
"quadratic_duration": 15.0, # seconds
"shuffle_buffer_size": 10,
"bucket_buffer_size": 100,
"seed": 0,
"shard_seed": "randomized",
}
)

# Data parallel, rank 0
dp0 = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity())

# Data parallel, rank 0 copy (is the iteration deterministic? -> yes)
dp0_cpy = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 0, incremented seed (paranoia mode: does the iteration order change with the seed? -> yes)
config2 = config.copy()
config2["seed"] = config2["seed"] + 1
dp0_incrseed = get_lhotse_dataloader_from_config(config=config2, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 1 (is data different on each DP rank? -> yes)
dp1 = get_lhotse_dataloader_from_config(config=config, global_rank=1, world_size=2, dataset=_Identity())

dloaders = zip(*[iter(dl) for dl in (dp0, dp0_cpy, dp0_incrseed, dp1)])

for i in range(5):
b0, b0_cpy, b0_incrseed, b1 = next(dloaders)
assert b0 == b0_cpy
assert b0 != b1
assert b0_incrseed != b1
assert b0 != b0_incrseed


def test_dataloader_multiple_ranks_trng(nemo_tarred_manifest_path: tuple[str, str]):
"""
This test is the same as ``test_dataloader_multiple_ranks_deterministic_rng``,
except that we set ``shard_seed="trng"`` which causes the seed to be lazily
resolved in subprocesses (resolved => being drawn using OS's TRNG).
Therefore, we don't expect any reproducibility.
"""
json_mft, tar_mft = nemo_tarred_manifest_path
config = OmegaConf.create(
{
"manifest_filepath": json_mft,
"tarred_audio_filepaths": tar_mft,
"sample_rate": 16000,
"shuffle": True,
"use_lhotse": True,
"num_workers": 1,
# lhotse specific
"use_bucketing": True,
"num_buckets": 2,
"drop_last": False,
"batch_duration": 4.0, # seconds
"quadratic_duration": 15.0, # seconds
"shuffle_buffer_size": 10,
"bucket_buffer_size": 100,
"seed": 0,
"shard_seed": "trng",
}
)

# Data parallel, rank 0
dp0 = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity())

# Data parallel, rank 0 copy (is the iteration deterministic? -> no, trng)
dp0_cpy = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 0, incremented seed (paranoia mode: does the iteration order change with the seed? -> yes)
config2 = config.copy()
config2["seed"] = config2["seed"] + 1
dp0_incrseed = get_lhotse_dataloader_from_config(config=config2, global_rank=0, world_size=2, dataset=_Identity(),)

# Data parallel, rank 1 (is data different on each DP rank? -> yes)
dp1 = get_lhotse_dataloader_from_config(config=config, global_rank=1, world_size=2, dataset=_Identity())

dloaders = zip(*[iter(dl) for dl in (dp0, dp0_cpy, dp0_incrseed, dp1)])

for i in range(5):
b0, b0_cpy, b0_incrseed, b1 = next(dloaders)
assert b0 != b0_cpy
assert b0 != b1
assert b0_incrseed != b1
assert b0 != b0_incrseed