forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improved random seed configuration for Lhotse dataloaders with docs (N…
…VIDIA#9001) * Improving RNG seeding with Lhotse dataloading Signed-off-by: Piotr Żelasko <petezor@gmail.com> * Fix Signed-off-by: Piotr Żelasko <petezor@gmail.com> * Add documentation about random seeds Signed-off-by: Piotr Żelasko <petezor@gmail.com> * Add doc about managing random seed Signed-off-by: Piotr Żelasko <petezor@gmail.com> * Apply suggestions from code review Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Signed-off-by: Piotr Żelasko <petezor@gmail.com> --------- Signed-off-by: Piotr Żelasko <petezor@gmail.com> Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com>
- Loading branch information
1 parent
bf2f6f0
commit d6e8ff1
Showing
5 changed files
with
235 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from io import BytesIO | ||
from pathlib import Path | ||
|
||
import pytest | ||
from lhotse import CutSet | ||
from lhotse.serialization import load_jsonl, save_to_jsonl | ||
from lhotse.shar.writers import JsonlShardWriter, TarWriter | ||
from lhotse.testing.dummies import DummyManifest | ||
from omegaconf import OmegaConf | ||
|
||
from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config | ||
|
||
|
||
class _Identity: | ||
def __getitem__(self, cuts): | ||
return cuts | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def cutset_path(tmp_path_factory) -> Path: | ||
"""10 utterances of length 1s as a Lhotse CutSet.""" | ||
cuts = DummyManifest(CutSet, begin_id=0, end_id=10, with_data=True) | ||
for c in cuts: | ||
c.features = None | ||
c.custom = None | ||
c.supervisions[0].custom = None | ||
|
||
tmp_path = tmp_path_factory.mktemp("data") | ||
p = tmp_path / "cuts.jsonl.gz" | ||
pa = tmp_path / "audio" | ||
cuts.save_audios(pa).to_file(p) | ||
return p | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def nemo_manifest_path(cutset_path: Path): | ||
"""10 utterances of length 1s as a NeMo manifest.""" | ||
nemo = [] | ||
for idx, c in enumerate(CutSet.from_file(cutset_path)): | ||
nemo.append( | ||
{"audio_filepath": c.recording.sources[0].source, "text": f"irrelevant-{idx}", "duration": c.duration,} | ||
) | ||
p = cutset_path.parent / "nemo_manifest.json" | ||
save_to_jsonl(nemo, p) | ||
return p | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> tuple[str, str]: | ||
"""5 shards, each with 2 utterances.""" | ||
root = nemo_manifest_path.parent / "nemo_tar" | ||
root.mkdir(exist_ok=True) | ||
with TarWriter(f"{root}/audios_%01d.tar", shard_size=2) as tar_writer, JsonlShardWriter( | ||
f"{root}/manifest_%01d.jsonl", shard_size=2 | ||
) as mft_writer: | ||
for idx, d in enumerate(load_jsonl(nemo_manifest_path)): | ||
p = d["audio_filepath"] | ||
name = Path(p).name | ||
with open(p, "rb") as f: | ||
tar_writer.write(name, BytesIO(f.read())) | ||
mft_writer.write({**d, "audio_filepath": name, "shard_id": idx // 2}) | ||
return f"{root}/manifest__OP_0..4_CL_.jsonl", f"{root}/audios__OP_0..4_CL_.tar" | ||
|
||
|
||
def test_dataloader_multiple_ranks_deterministic_rng(nemo_tarred_manifest_path: tuple[str, str]): | ||
json_mft, tar_mft = nemo_tarred_manifest_path | ||
config = OmegaConf.create( | ||
{ | ||
"manifest_filepath": json_mft, | ||
"tarred_audio_filepaths": tar_mft, | ||
"sample_rate": 16000, | ||
"shuffle": True, | ||
"use_lhotse": True, | ||
"num_workers": 1, | ||
# lhotse specific | ||
"use_bucketing": True, | ||
"num_buckets": 2, | ||
"drop_last": False, | ||
"batch_duration": 4.0, # seconds | ||
"quadratic_duration": 15.0, # seconds | ||
"shuffle_buffer_size": 10, | ||
"bucket_buffer_size": 100, | ||
"seed": 0, | ||
"shard_seed": "randomized", | ||
} | ||
) | ||
|
||
# Data parallel, rank 0 | ||
dp0 = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity()) | ||
|
||
# Data parallel, rank 0 copy (is the iteration deterministic? -> yes) | ||
dp0_cpy = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity(),) | ||
|
||
# Data parallel, rank 0, incremented seed (paranoia mode: does the iteration order change with the seed? -> yes) | ||
config2 = config.copy() | ||
config2["seed"] = config2["seed"] + 1 | ||
dp0_incrseed = get_lhotse_dataloader_from_config(config=config2, global_rank=0, world_size=2, dataset=_Identity(),) | ||
|
||
# Data parallel, rank 1 (is data different on each DP rank? -> yes) | ||
dp1 = get_lhotse_dataloader_from_config(config=config, global_rank=1, world_size=2, dataset=_Identity()) | ||
|
||
dloaders = zip(*[iter(dl) for dl in (dp0, dp0_cpy, dp0_incrseed, dp1)]) | ||
|
||
for i in range(5): | ||
b0, b0_cpy, b0_incrseed, b1 = next(dloaders) | ||
assert b0 == b0_cpy | ||
assert b0 != b1 | ||
assert b0_incrseed != b1 | ||
assert b0 != b0_incrseed | ||
|
||
|
||
def test_dataloader_multiple_ranks_trng(nemo_tarred_manifest_path: tuple[str, str]): | ||
""" | ||
This test is the same as ``test_dataloader_multiple_ranks_deterministic_rng``, | ||
except that we set ``shard_seed="trng"`` which causes the seed to be lazily | ||
resolved in subprocesses (resolved => being drawn using OS's TRNG). | ||
Therefore, we don't expect any reproducibility. | ||
""" | ||
json_mft, tar_mft = nemo_tarred_manifest_path | ||
config = OmegaConf.create( | ||
{ | ||
"manifest_filepath": json_mft, | ||
"tarred_audio_filepaths": tar_mft, | ||
"sample_rate": 16000, | ||
"shuffle": True, | ||
"use_lhotse": True, | ||
"num_workers": 1, | ||
# lhotse specific | ||
"use_bucketing": True, | ||
"num_buckets": 2, | ||
"drop_last": False, | ||
"batch_duration": 4.0, # seconds | ||
"quadratic_duration": 15.0, # seconds | ||
"shuffle_buffer_size": 10, | ||
"bucket_buffer_size": 100, | ||
"seed": 0, | ||
"shard_seed": "trng", | ||
} | ||
) | ||
|
||
# Data parallel, rank 0 | ||
dp0 = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity()) | ||
|
||
# Data parallel, rank 0 copy (is the iteration deterministic? -> no, trng) | ||
dp0_cpy = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=2, dataset=_Identity(),) | ||
|
||
# Data parallel, rank 0, incremented seed (paranoia mode: does the iteration order change with the seed? -> yes) | ||
config2 = config.copy() | ||
config2["seed"] = config2["seed"] + 1 | ||
dp0_incrseed = get_lhotse_dataloader_from_config(config=config2, global_rank=0, world_size=2, dataset=_Identity(),) | ||
|
||
# Data parallel, rank 1 (is data different on each DP rank? -> yes) | ||
dp1 = get_lhotse_dataloader_from_config(config=config, global_rank=1, world_size=2, dataset=_Identity()) | ||
|
||
dloaders = zip(*[iter(dl) for dl in (dp0, dp0_cpy, dp0_incrseed, dp1)]) | ||
|
||
for i in range(5): | ||
b0, b0_cpy, b0_incrseed, b1 = next(dloaders) | ||
assert b0 != b0_cpy | ||
assert b0 != b1 | ||
assert b0_incrseed != b1 | ||
assert b0 != b0_incrseed |