Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Frame-VAD model and datasets #6441

Merged
merged 12 commits into from
May 2, 2023
452 changes: 446 additions & 6 deletions nemo/collections/asr/data/audio_to_label.py

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions nemo/collections/asr/data/audio_to_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# limitations under the License.
import copy

from omegaconf import DictConfig

from nemo.collections.asr.data import audio_to_label
from nemo.collections.asr.data.audio_to_text_dataset import convert_to_config_list, get_chain_dataset
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.dataset import ConcatDataset


Expand Down Expand Up @@ -217,3 +220,87 @@ def get_tarred_speech_label_dataset(
datasets.append(dataset)

return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)


def get_audio_multi_label_dataset(cfg: DictConfig) -> audio_to_label.AudioToMultiLabelDataset:
if "augmentor" in cfg:
augmentor = process_augmentations(cfg.augmentor)
else:
augmentor = None

dataset = audio_to_label.AudioToMultiLabelDataset(
manifest_filepath=cfg.get("manifest_filepath"),
sample_rate=cfg.get("sample_rate"),
labels=cfg.get("labels", None),
int_values=cfg.get("int_values", False),
augmentor=augmentor,
min_duration=cfg.get("min_duration", None),
max_duration=cfg.get("max_duration", None),
trim_silence=cfg.get("trim_silence", False),
is_regression_task=cfg.get("is_regression_task", False),
cal_labels_occurrence=cfg.get("cal_labels_occurrence", False),
delimiter=cfg.get("delimiter", None),
normalize_audio_db=cfg.get("normalize_audio_db", False),
normalize_audio_db_target=cfg.get("normalize_audio_db_target", -20),
)
return dataset


def get_tarred_audio_multi_label_dataset(
cfg: DictConfig, shuffle_n: int, global_rank: int, world_size: int
) -> audio_to_label.TarredAudioToMultiLabelDataset:

if "augmentor" in cfg:
augmentor = process_augmentations(cfg.augmentor)
else:
augmentor = None

tarred_audio_filepaths = cfg['tarred_audio_filepaths']
manifest_filepaths = cfg['manifest_filepath']
datasets = []
tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
manifest_filepaths = convert_to_config_list(manifest_filepaths)

bucketing_weights = cfg.get('bucketing_weights', None) # For upsampling buckets
if bucketing_weights:
for idx, weight in enumerate(bucketing_weights):
if not isinstance(weight, int) or weight <= 0:
raise ValueError(f"bucket weights must be positive integers")

if len(manifest_filepaths) != len(tarred_audio_filepaths):
raise ValueError(
f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
)

for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
zip(tarred_audio_filepaths, manifest_filepaths)
):
if len(tarred_audio_filepath) == 1:
tarred_audio_filepath = tarred_audio_filepath[0]

dataset = audio_to_label.TarredAudioToMultiLabelDataset(
audio_tar_filepaths=tarred_audio_filepath,
manifest_filepath=manifest_filepath,
sample_rate=cfg["sample_rate"],
labels=cfg['labels'],
shuffle_n=shuffle_n,
int_values=cfg.get("int_values", False),
augmentor=augmentor,
min_duration=cfg.get('min_duration', None),
max_duration=cfg.get('max_duration', None),
trim_silence=cfg.get('trim_silence', False),
is_regression_task=cfg.get('is_regression_task', False),
delimiter=cfg.get("delimiter", None),
shard_strategy=cfg.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
normalize_audio_db=cfg.get("normalize_audio_db", False),
normalize_audio_db_target=cfg.get("normalize_audio_db_target", -20),
)

if bucketing_weights:
[datasets.append(dataset) for _ in range(bucketing_weights[dataset_idx])]
else:
datasets.append(dataset)

return get_chain_dataset(datasets=datasets, ds_config=cfg, rank=global_rank)
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
from nemo.collections.asr.models.classification_models import EncDecClassificationModel, EncDecFrameClassificationModel
from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
Expand Down
Loading