From 673b3eb0a66621b5355ded4e3b0222f907c77808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 25 Jan 2024 15:59:20 -0500 Subject: [PATCH 1/7] Create a separate CanaryDataset and use it inside `transformer_bpe_models.py`. Ditches `token_sequence_format`. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_canary.py | 104 ++++++++++++++++++ .../asr/data/audio_to_text_lhotse.py | 80 +------------- .../asr/models/transformer_bpe_models.py | 7 +- 3 files changed, 109 insertions(+), 82 deletions(-) create mode 100644 nemo/collections/asr/data/audio_to_text_canary.py diff --git a/nemo/collections/asr/data/audio_to_text_canary.py b/nemo/collections/asr/data/audio_to_text_canary.py new file mode 100644 index 000000000000..e2f89bf78c2a --- /dev/null +++ b/nemo/collections/asr/data/audio_to_text_canary.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.utils.data +from lhotse.cut import MixedCut, MonoCut +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors + +from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper + + +class CanaryDataset(torch.utils.data.Dataset): + """ + This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`. + It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors. + The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce + a special prompt format for Canary model, which has an encoder-decoder architecture. + """ + + def __init__(self, tokenizer): + super().__init__() + self.tokenizer = TokenizerWrapper(tokenizer) + self.load_audio = AudioSamples(fault_tolerant=True) + self.padding_value = self.tokenizer._tokenizer.pad_id + + def __getitem__(self, cuts) -> tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + + tokens = [self.tokenizer(c.supervisions[0].text, c.supervisions[0].language) for c in cuts] + tokens = self._canary_format(tokens, cuts) + tokens = [torch.as_tensor(t) for t in tokens] + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=self.padding_value) + + return audio, audio_lens, tokens, token_lens + + def _canary_format(self, tokens, cuts): + """ + prepend and append control tokens to the token sequence as per canary format + + Format: + sot, src_lang_id/no_speech, transcribe/translate, tgt_lang_id, text, eot + """ + canary_tokens = [] + for t, c in zip(tokens, cuts): + if isinstance(c, MixedCut): + c = c._first_non_padding_cut + assert isinstance(c, MonoCut), "Expected MonoCut." 
+ + c_t = [] # canary_tokens for this cut + + # bos + c_t.append(self.tokenizer._tokenizer.bos_id) + + # if len(t) is 0 append no-speech token + if len(t) == 0: + c_t.append(self.tokenizer._tokenizer.nospeech_id) + else: + # src_lang_id/no_speech + src_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['source_lang']) + c_t.append(src_lang_id) + + # task + task = c.custom['taskname'] + if task == 'asr': + c_t.append(self.tokenizer._tokenizer.transcribe_id) + elif task == 's2t_translation': + c_t.append(self.tokenizer._tokenizer.translate_id) + else: + raise ValueError(f"Unknown task: {task}") + + # tgt_lang_id + tgt_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['target_lang']) + c_t.append(tgt_lang_id) + + # PnC + pnc = f"{c.custom['pnc']}".lower().strip() # to account for bool or str + if pnc in set(['yes', 'true']): + c_t.append(self.tokenizer._tokenizer.pnc_id) + elif pnc in set(['no', 'false']): + c_t.append(self.tokenizer._tokenizer.nopnc_id) + else: + raise ValueError(f"Unknown PnC: {pnc}") + + # text + c_t.extend(t) + + # eos + c_t.append(self.tokenizer._tokenizer.eos_id) + + canary_tokens.append(c_t) + + return canary_tokens diff --git a/nemo/collections/asr/data/audio_to_text_lhotse.py b/nemo/collections/asr/data/audio_to_text_lhotse.py index b8f1867624d4..21cecedeadbb 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse.py @@ -15,7 +15,6 @@ from typing import Dict, Optional, Tuple import torch.utils.data -from lhotse.cut import MixedCut, MonoCut from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_vectors @@ -44,91 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } - def __init__(self, tokenizer, token_sequence_format: str = None): + def __init__(self, tokenizer): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) - assert token_sequence_format is None or token_sequence_format in [ - 'canary' - ], f"Unsupported token_sequence_format: {token_sequence_format}" - self.token_sequence_format = token_sequence_format def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) - - tokens = [self.tokenizer(c.supervisions[0].text, c.supervisions[0].language) for c in cuts] - if self.token_sequence_format == 'canary': - tokens = self._canary_format(tokens, cuts) - tokens = [torch.as_tensor(t) for t in tokens] - + tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts] token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) - - if self.token_sequence_format == 'canary': - padding_value = self.tokenizer._tokenizer.pad_id - else: - padding_value = 0 - tokens = collate_vectors(tokens, padding_value=padding_value) - + tokens = collate_vectors(tokens, padding_value=0) return audio, audio_lens, tokens, token_lens - def _canary_format(self, tokens, cuts): - """ - prepend and append control tokens to the token sequence as per canary format - - Format: - sot, src_lang_id/no_speech, transcribe/translate, tgt_lang_id, text, eot - """ - canary_tokens = [] - for t, c in zip(tokens, cuts): - if isinstance(c, MixedCut): - c = c._first_non_padding_cut - assert isinstance(c, MonoCut), "Expected MonoCut." 
- - c_t = [] # canary_tokens for this cut - - # bos - c_t.append(self.tokenizer._tokenizer.bos_id) - - # if len(t) is 0 append no-speech token - if len(t) == 0: - c_t.append(self.tokenizer._tokenizer.nospeech_id) - else: - # src_lang_id/no_speech - src_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['source_lang']) - c_t.append(src_lang_id) - - # task - task = c.custom['taskname'] - if task == 'asr': - c_t.append(self.tokenizer._tokenizer.transcribe_id) - elif task == 's2t_translation': - c_t.append(self.tokenizer._tokenizer.translate_id) - else: - raise ValueError(f"Unknown task: {task}") - - # tgt_lang_id - tgt_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['target_lang']) - c_t.append(tgt_lang_id) - - # PnC - pnc = f"{c.custom['pnc']}".lower().strip() # to account for bool or str - if pnc in set(['yes', 'true']): - c_t.append(self.tokenizer._tokenizer.pnc_id) - elif pnc in set(['no', 'false']): - c_t.append(self.tokenizer._tokenizer.nopnc_id) - else: - raise ValueError(f"Unknown PnC: {pnc}") - - # text - c_t.extend(t) - - # eos - c_t.append(self.tokenizer._tokenizer.eos_id) - - canary_tokens.append(c_t) - - return canary_tokens - class TokenizerWrapper: """ diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 471030ee3127..6d94a9e53a5e 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -28,8 +28,8 @@ from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_canary import CanaryDataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs -from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.parts.mixins import ASRBPEMixin from nemo.collections.asr.parts.utils import manifest_utils @@ -367,9 +367,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config, global_rank=self.global_rank, world_size=self.world_size, - dataset=LhotseSpeechToTextBpeDataset( - tokenizer=self.tokenizer, token_sequence_format=config.get("token_sequence_format", None), - ), + dataset=CanaryDataset(tokenizer=self.tokenizer), ) dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( @@ -750,7 +748,6 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'is_tarred': False, 'batch_size': 1, 'force_strip_pnc': False, - 'token_sequence_format': "canary", 'use_lhotse': True, 'lhotse': { 'use_bucketing': False, From 048af28a5c6af42aaa9f33aa200fcb3847939c25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 26 Jan 2024 14:46:06 -0500 Subject: [PATCH 2/7] =?UTF-8?q?[canary]=20Refactor:=20move=20changes=20in?= =?UTF-8?q?=20transformer=5Fbpe=5Fmodels.py=20to=20Canar=E2=80=A6=20(#8252?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [canary] Refactor: move changes in transformer_bpe_models.py to CanaryModel Signed-off-by: Piotr Żelasko * Rename `CanaryModel` to `EncDecMultiTaskModel` and remove inheritance from `EncDecTransfModelBPE`; add a separate config for this model Signed-off-by: Piotr Żelasko --------- Signed-off-by: Piotr Żelasko --- .../speech_multitask/fast-conformer_aed.yaml | 230 ++++++ nemo/collections/asr/models/__init__.py | 1 + 
.../asr/models/aed_multitask_models.py | 721 ++++++++++++++++++ .../asr/models/transformer_bpe_models.py | 151 +--- 4 files changed, 965 insertions(+), 138 deletions(-) create mode 100644 examples/asr/conf/speech_multitask/fast-conformer_aed.yaml create mode 100644 nemo/collections/asr/models/aed_multitask_models.py diff --git a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml new file mode 100644 index 000000000000..6283b364bec4 --- /dev/null +++ b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml @@ -0,0 +1,230 @@ +# It contains the default values for training an autoregressive FastConformer-Transformer ST model with sub-word encoding. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file. +# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes +# It is recommended to initialize FastConformer with ASR pre-trained encoder for better accuracy and faster convergence + +name: "FastConformer-Transformer-MultiTask" + +# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy +init_from_nemo_model: + model0: + path: ??? + include: ["preprocessor", "encoder"] + +model: + _target_: nemo.collections.asr.models.EncDecMultiTaskModel + sample_rate: 16000 + label_smoothing: 0.0 + log_prediction: true # enables logging sample predictions in the output during training + + train_ds: + is_tarred: true + tarred_audio_filepaths: ??? + manifest_filepath: ??? + sample_rate: 16000 + shuffle: true + trim_silence: false + num_workers: 8 + use_lhotse: true + batch_size: None + batch_duration: 360 + quadratic_duration: 20 + use_bucketing: True + num_buckets: 20 + bucket_buffer_size: 20000 + shuffle_buffer_size: 10000 + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 4 + pin_memory: true + use_start_end_token: true + use_lhotse: true + use_bucketing: false + + test_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 4 + pin_memory: true + use_start_end_token: true + use_lhotse: true + use_bucketing: false + + # recommend small vocab size of 128 or 256 when using 4x sub-sampling + # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? 
# path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + pad_value: 0.0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + # you may use lower time_masks for smaller models to have a faster convergence + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling params + subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory + subsampling_factor: 8 # must be power of 2 + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + reduction: null + reduction_position: null + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: batch_norm + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + transf_encoder: + num_layers: 0 + hidden_size: 512 + inner_size: 2048 + num_attention_heads: 8 + ffn_dropout: 0.1 + attn_score_dropout: 0.1 + attn_layer_dropout: 0.1 + + transf_decoder: + library: nemo + model_name: null + pretrained: false + max_sequence_length: 512 + num_token_types: 0 + embedding_dropout: 0.1 + learn_positional_encodings: false + hidden_size: 512 + inner_size: 2048 + num_layers: 6 + num_attention_heads: 4 + ffn_dropout: 0.1 + attn_score_dropout: 0.1 + attn_layer_dropout: 0.1 + hidden_act: relu + pre_ln: true + pre_ln_final_layer_norm: true + + head: + num_layers: 1 + activation: relu + log_softmax: true + dropout: 0.0 + use_transformer_init: true + + beam_search: + beam_size: 4 + len_pen: 0.0 + max_generation_delta: 50 + + optim: + name: adam + lr: 0.0001 + # optimizer arguments + betas: [0.9, 0.98] + # less necessity for weight_decay as we already have large augmentations with SpecAug + # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used + # weight decay of 0.0 with lr of 2.0 also works fine + #weight_decay: 1e-3 + + # scheduler setup + sched: + name: InverseSquareRootAnnealing + #d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 1000 + warmup_ratio: null + min_lr: 1e-6 + 
+trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 100 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 100 # Interval of logging. + enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_sacreBLEU" + mode: "max" + save_top_k: 3 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 34f2c4f62e29..2b5659066daa 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel from nemo.collections.asr.models.classification_models import EncDecClassificationModel, EncDecFrameClassificationModel diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py new file mode 100644 index 000000000000..7d4e036e22e9 --- /dev/null +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -0,0 +1,721 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import itertools +import json +import os +import re +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import editdistance +import torch +import torch.distributed as dist +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data.audio_to_text_canary import CanaryDataset +from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs +from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.asr.parts.utils import manifest_utils +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.losses import SmoothedCrossEntropyLoss +from nemo.collections.common.metrics import GlobalAverageLossMetric +from nemo.collections.common.parts import transformer_weights_init +from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from nemo.core.classes.common import typecheck +from nemo.core.neural_types import ( + AudioSignal, + ChannelType, + LabelsType, + LengthsType, + LogprobsType, + MaskType, + NeuralType, + SpectrogramType, +) +from nemo.utils import logging + +try: + from sacrebleu import corpus_bleu + + from nemo.collections.nlp.modules.common import TokenClassifier + from nemo.collections.nlp.modules.common.lm_utils import get_transformer + from nemo.collections.nlp.modules.common.transformer import BeamSearchSequenceGenerator, TransformerEncoder + + NLP_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + NLP_AVAILABLE = False + logging.warning("Could not import NeMo NLP collection which is required for speech translation model.") + +__all__ = ['EncDecMultiTaskModel'] + + +def lens_to_mask(lens, max_length): + batch_size = lens.shape[0] + mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None] + return mask + + +class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin): + """Base class for encoder decoder CTC-based models.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + # Setup the tokenizer + self._setup_tokenizer(cfg.tokenizer) + + super().__init__(cfg=cfg, trainer=trainer) + + # Setup audio preprocessor + self.preprocessor = EncDecMultiTaskModel.from_config_dict(self.cfg.preprocessor) + + # Setup audio encoder + self.encoder = EncDecMultiTaskModel.from_config_dict(self.cfg.encoder) + + # Add projection layer if encoder and decoder differ in hidden size + if self.cfg.encoder['d_model'] != self.cfg.transf_decoder['hidden_size']: + self.adapter = torch.nn.Linear(self.cfg.encoder['d_model'], self.cfg.transf_decoder['hidden_size']) + else: + self.adapter = torch.nn.Identity() + + transf_encoder_cfg_dict = OmegaConf.to_container(cfg.get('transf_encoder')) + + # Whether to add Transformer Encoder block between Conformer and Transformer Decoder + self.use_transf_encoder = False + if transf_encoder_cfg_dict['num_layers'] > 0: + self.use_transf_encoder = True + + self.transf_encoder = TransformerEncoder( + num_layers=transf_encoder_cfg_dict['num_layers'], + hidden_size=transf_encoder_cfg_dict['hidden_size'], + inner_size=transf_encoder_cfg_dict['inner_size'], + mask_future=False, + 
num_attention_heads=transf_encoder_cfg_dict['num_attention_heads'], + attn_score_dropout=transf_encoder_cfg_dict['attn_score_dropout'], + attn_layer_dropout=transf_encoder_cfg_dict['attn_layer_dropout'], + ffn_dropout=transf_encoder_cfg_dict['ffn_dropout'], + pre_ln=transf_encoder_cfg_dict.get('pre_ln', True), + pre_ln_final_layer_norm=transf_encoder_cfg_dict.get('pre_ln_final_layer_norm', True), + ) + std_init_range = 1 / transf_encoder_cfg_dict['hidden_size'] ** 0.5 + self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range)) + + transf_decoder_cfg_dict = OmegaConf.to_container(cfg.get('transf_decoder')) + + # Transformer decoder + vocab_size = 8 * ceil(self.tokenizer.vocab_size / 8) + transf_decoder_cfg_dict['vocab_size'] = vocab_size + library = transf_decoder_cfg_dict.pop('library', 'nemo') + model_name = transf_decoder_cfg_dict.pop('model_name', None) + pretrained = transf_decoder_cfg_dict.pop('pretrained', False) + self.transf_decoder = get_transformer( + library=library, + model_name=model_name, + pretrained=pretrained, + config_dict=transf_decoder_cfg_dict, + encoder=False, + pre_ln_final_layer_norm=transf_decoder_cfg_dict.get("pre_ln_final_layer_norm", False), + ) + + self.log_softmax = TokenClassifier( + hidden_size=self.transf_decoder.hidden_size, + num_classes=vocab_size, + activation=self.cfg.head.activation, + log_softmax=self.cfg.head.log_softmax, + dropout=self.cfg.head.dropout, + use_transformer_init=self.cfg.head.use_transformer_init, + ) + self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight + std_init_range = 1 / self.transf_decoder.hidden_size ** 0.5 + self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) + self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) + + # Beam Search decoding + self.beam_search = BeamSearchSequenceGenerator( + embedding=self.transf_decoder.embedding, + decoder=self.transf_decoder.decoder, + log_softmax=self.log_softmax, + max_sequence_length=self.transf_decoder.max_sequence_length, + beam_size=self.cfg.beam_search.beam_size, + bos=self.tokenizer.bos_id, + pad=self.tokenizer.pad_id, + eos=self.tokenizer.eos_id, + len_pen=self.cfg.beam_search.len_pen, + max_delta_length=self.cfg.beam_search.max_generation_delta, + ) + # TO DO: remove this hardcoded context size for AR decoding + self.context_len_for_AR_decoding = 5 + + # Define autoregressive CE loss + self.transf_loss = SmoothedCrossEntropyLoss( + pad_id=self.tokenizer.pad_id, label_smoothing=self.cfg.label_smoothing + ) + + if hasattr(self.cfg, 'spec_augment') and self.cfg.spec_augment is not None: + self.spec_augmentation = EncDecMultiTaskModel.from_config_dict(self.cfg.spec_augment) + else: + self.spec_augmentation = None + + self.val_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) + + def change_decoding_strategy(self, cfg: DictConfig): + logging.info(f"Changing beam search decoding to {cfg}") + # Beam Search decoding + self.beam_search = BeamSearchSequenceGenerator( + embedding=self.transf_decoder.embedding, + decoder=self.transf_decoder.decoder, + log_softmax=self.log_softmax, + max_sequence_length=self.transf_decoder.max_sequence_length, + beam_size=cfg.beam_size, + bos=self.tokenizer.bos_id, + pad=self.tokenizer.pad_id, + eos=self.tokenizer.eos_id, + len_pen=cfg.len_pen, + max_delta_length=cfg.max_generation_delta, + ) + + @torch.no_grad() + def transcribe( + self, + paths2audio_files: Union[List[str], str], + 
batch_size: int = 4, + logprobs: bool = False, + return_hypotheses: bool = False, + num_workers: int = 0, + channel_selector: Optional[ChannelSelectorType] = None, + augmentor: DictConfig = None, + verbose: bool = True, + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + num_workers: (int) number of workers for DataLoader + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + verbose: (bool) whether to display tqdm progress bar + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses: + logging.warning("return_hypotheses=True is currently not supported, returning text instead.") + + if isinstance(paths2audio_files, list): + logging.info(f"Found paths2audio_files to be a list of {len(paths2audio_files)} items.") + logging.info(f"Assuming each item in paths2audio_files is a path to audio file.") + logging.info(f"Transcribing with default Canary setting of English without PnC.") + elif isinstance(paths2audio_files, str): + logging.info(f"Found paths2audio_files to be a string. Assuming it is a path to manifest file.") + assert os.path.exists(paths2audio_files), f"File {paths2audio_files} doesn't exist" + assert paths2audio_files.endswith('.json') or paths2audio_files.endswith( + '.jsonl' + ), f"File {paths2audio_files} must be a json or jsonl file" + + # load json lines + manifest_path = paths2audio_files # need to save this as we are overwriting paths2audio_files in nextline + paths2audio_files = manifest_utils.read_manifest(paths2audio_files) + + def _may_be_make_dict_and_fix_paths(json_items): + out_json_items = [] + for item in json_items: + if isinstance(item, str): + # assume it is a path to audio file + entry = { + 'audio_filepath': item, + 'duration': 100000, + 'source_lang': 'en', + 'taskname': 'asr', + 'target_lang': 'en', + 'pnc': 'no', + 'answer': 'nothing', + } + elif isinstance(item, dict): + entry = item + entry['audio_filepath'] = get_full_path(entry['audio_filepath'], manifest_file=manifest_path) + else: + raise ValueError(f"Expected str or dict, got {type(item)}") + out_json_items.append(entry) + return out_json_items + + paths2audio_files = _may_be_make_dict_and_fix_paths(paths2audio_files) + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." 
+ ) + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + # We will store transcriptions here + hypotheses = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.transf_decoder.freeze() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + # _may_be_make_dict_and_fix_paths has already fixed the path and added other fields if needed + fp.write(json.dumps(audio_file) + '\n') + + config = { + 'paths2audio_files': paths2audio_files, + 'batch_size': batch_size, + 'temp_dir': tmpdir, + 'num_workers': num_workers, + 'channel_selector': channel_selector, + } + + if augmentor: + config['augmentor'] = augmentor + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose): + log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + + beam_hypotheses = ( + self.beam_search( + encoder_hidden_states=enc_states, + encoder_input_mask=enc_mask, + return_beam_scores=False, + decoder_input_ids=test_batch[2][:, : self.context_len_for_AR_decoding].to(device) + if self.context_len_for_AR_decoding > 0 + else None, + ) + .detach() + .cpu() + .numpy() + ) + + beam_hypotheses = [ + self._strip_special_tokens(self.tokenizer.ids_to_text(hyp)) for hyp in beam_hypotheses + ] + + # TODO: add support for return_hypotheses=True @AlexGrinch + # if return_hypotheses: + # # dump log probs per file + # for idx in range(logits.shape[0]): + # current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += beam_hypotheses + + del test_batch, log_probs, encoded_len, enc_states, enc_mask + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + if mode is True: + self.encoder.unfreeze() + self.transf_decoder.unfreeze() + logging.set_verbosity(logging_level) + + return hypotheses + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + assert config.get("use_lhotse", False), ( + "Multi-task model only supports dataloading with Lhotse. " + "Please set config.{train,validation,test}_ds.use_lhotse=True" + ) + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=CanaryDataset(tokenizer=self.tokenizer), + ) + + def setup_training_data(self, train_data_config: Optional[DictConfig]): + + # create audio-only data loader + self._update_dataset_config(dataset_name='train', config=train_data_config) + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the + # dataloader is the total number of samples rather than the number of batches, + # and this messes up the tqdm progress bar. 
So we set the number of steps manually + # (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, + # i.e. <= # training batches, and don't change it. Otherwise, adjust + # batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. 
+ Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "transcript": NeuralType(('B', 'T'), LabelsType(), optional=True), + "transcript_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "transf_log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "encoder_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "encoder_mask": NeuralType(('B', 'T'), MaskType()), + } + + @typecheck() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + transcript=None, + transcript_length=None, + ): + """ + Forward pass of the model. + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + processed_signal: Tensor that represents a batch of processed audio signals, + of shape (B, D, T) that has undergone processing via some DALI preprocessor. + processed_signal_length: Vector of length B, that contains the individual lengths of the + processed audio sequences. + Returns: + A tuple of 3 elements - + 1) The log probabilities tensor of shape [B, T, D]. + 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. + 3) The greedy token predictions of the model of shape [B, T] (via argmax) + """ + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) == False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." 
+ ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, length=input_signal_length + ) + + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + + enc_states = encoded.permute(0, 2, 1) + enc_states = self.adapter(enc_states) + enc_mask = lens_to_mask(encoded_len, enc_states.shape[1]).to(enc_states.dtype) + if self.use_transf_encoder: + enc_states = self.transf_encoder(encoder_states=enc_states, encoder_mask=enc_mask) + + transf_log_probs = None + if transcript is not None: + dec_mask = lens_to_mask(transcript_length, transcript.shape[1]).to(transcript.dtype) + dec_states = self.transf_decoder( + input_ids=transcript, decoder_mask=dec_mask, encoder_embeddings=enc_states, encoder_mask=enc_mask + ) + transf_log_probs = self.log_softmax(hidden_states=dec_states) + + return transf_log_probs, encoded_len, enc_states, enc_mask + + def compute_audio_loss(self, batch): + + if batch is None: + return 0 + + signal, signal_len, transcript, transcript_len = batch + input_ids, labels = transcript[:, :-1], transcript[:, 1:] + + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=signal, + input_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + + transf_loss = self.transf_loss(log_probs=transf_log_probs, labels=labels) + + return transf_loss + + # PTL-specific methods + def training_step(self, batch, batch_nb): + + audio_loss = self.compute_audio_loss(batch) + + tensorboard_logs = { + 'train_loss': audio_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + } + + return {'loss': audio_loss, 'log': tensorboard_logs} + + def _strip_special_tokens(self, text): + """ + assuming all special tokens are of format + Note that if any label/pred is of format , it will be stripped + """ + assert isinstance(text, str), f"Expected str, got {type(text)}" + text = re.sub(r'<[^>]+>', '', text) + # strip spaces at the beginning and end; + # this is training data artifact, will be fixed in future (@kpuvvada) + return text.strip() + + def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): + signal, signal_len, transcript, transcript_len = batch + input_ids, labels = transcript[:, :-1], transcript[:, 1:] + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + processed_signal=signal, + processed_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + else: + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=signal, + input_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + + beam_hypotheses = self.beam_search( + encoder_hidden_states=enc_states, + encoder_input_mask=enc_mask, + return_beam_scores=False, + decoder_input_ids=input_ids[:, : self.context_len_for_AR_decoding] + if self.context_len_for_AR_decoding > 0 + else None, + ) + transf_loss = self.transf_loss(log_probs=transf_log_probs, labels=labels) + + ground_truths = [self.tokenizer.ids_to_text(sent) for sent in transcript.detach().cpu().tolist()] + translations = [self.tokenizer.ids_to_text(sent) for sent in beam_hypotheses.detach().cpu().tolist()] + + self.val_loss(loss=transf_loss, 
num_measurements=transf_log_probs.shape[0] * transf_log_probs.shape[1]) + + output_dict = { + f'{eval_mode}_loss': transf_loss, + 'translations': [self._strip_special_tokens(t) for t in translations], + 'ground_truths': [self._strip_special_tokens(g) for g in ground_truths], + } + + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(output_dict) + else: + self.validation_step_outputs.append(output_dict) + + return output_dict + + def test_step(self, batch, batch_idx, dataloader_idx=0): + return self.validation_step(batch, batch_idx, dataloader_idx, eval_mode="test") + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, eval_mode: str = "val"): + """ + Called at the end of validation to aggregate outputs. + :param outputs: list of individual outputs of each validation step. + """ + if not outputs: + return + + if isinstance(outputs[0], dict): + outputs = [outputs] + + for output in outputs: + eval_loss = getattr(self, 'val_loss').compute() + translations = list(itertools.chain(*[x['translations'] for x in output])) + ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output])) + + # Gather translations and ground truths from all workers + tr_and_gt = [None for _ in range(self.world_size)] + # we also need to drop pairs where ground truth is an empty string + if self.world_size > 1: + dist.all_gather_object( + tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] + ) + else: + tr_and_gt[0] = [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] + + if self.global_rank == 0: + _translations = [] + _ground_truths = [] + for rank in range(0, self.world_size): + _translations += [t for (t, g) in tr_and_gt[rank]] + _ground_truths += [g for (t, g) in tr_and_gt[rank]] + + sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a") + sb_score = sacre_bleu.score * self.world_size + + wer_scores, wer_words = 0, 0 + for h, r in zip(_translations, _ground_truths): + wer_words += len(r.split()) + wer_scores += editdistance.eval(h.split(), r.split()) + wer_score = 1.0 * wer_scores * self.world_size / wer_words + + else: + sb_score = 0.0 + wer_score = 0.0 + + # To log via on_validation_epoch_end in modelPT.py + # remove (* self.world_size) if logging via on_validation_epoch_end + # tensorboard_logs = {} + # tensorboard_logs.update({f"{eval_mode}_loss": eval_loss}) + # tensorboard_logs.update({f"{eval_mode}_sacreBLEU": sb_score}) + # tensorboard_logs.update({f"{eval_mode}_WER": wer_score}) + + # logging here only. + dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx) + self.log(f"{dataloader_prefix}{eval_mode}_loss", eval_loss, sync_dist=True) + self.log(f"{dataloader_prefix}{eval_mode}_sacreBLEU", sb_score, sync_dist=True) + self.log(f"{dataloader_prefix}{eval_mode}_WER", wer_score, sync_dist=True) + + # in multi-validation case, anything after first one will become NaN + # as we are resetting the metric here. 
+ # TODO: fix this, (not sure which hook will be ideal for this) + self.val_loss.reset() + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + return self.multi_validation_epoch_end(outputs, dataloader_idx, eval_mode="test") + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + Returns: + A pytorch DataLoader for the given audio file(s). + """ + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': batch_size, + 'trim_silence': False, + 'shuffle': False, + 'num_workers': min(batch_size, os.cpu_count() - 1), + 'pin_memory': True, + 'use_lhotse': True, + 'use_bucketing': False, + 'drop_last': False, + 'text_field': 'answer', + 'lang_field': 'target_lang', + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 6d94a9e53a5e..e1bef823f5f5 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -15,7 +15,6 @@ import itertools import json import os -import re import tempfile from math import ceil from typing import Dict, List, Optional, Union @@ -28,17 +27,15 @@ from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset -from nemo.collections.asr.data.audio_to_text_canary import CanaryDataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.parts.mixins import ASRBPEMixin -from nemo.collections.asr.parts.utils import manifest_utils from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.losses import SmoothedCrossEntropyLoss from nemo.collections.common.metrics import GlobalAverageLossMetric from nemo.collections.common.parts import transformer_weights_init -from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.core.classes.common import typecheck from nemo.core.neural_types import ( AudioSignal, @@ -163,8 +160,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): len_pen=self.cfg.beam_search.len_pen, max_delta_length=self.cfg.beam_search.max_generation_delta, ) - # TO DO: remove this hardcoded context size for AR decoding - self.context_len_for_AR_decoding = 5 # Define autoregressive CE loss self.transf_loss = SmoothedCrossEntropyLoss( @@ -178,26 +173,10 @@ 
def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.val_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) - def change_decoding_strategy(self, cfg: DictConfig): - logging.info(f"Changing beam search decoding to {cfg}") - # Beam Search decoding - self.beam_search = BeamSearchSequenceGenerator( - embedding=self.transf_decoder.embedding, - decoder=self.transf_decoder.decoder, - log_softmax=self.log_softmax, - max_sequence_length=self.transf_decoder.max_sequence_length, - beam_size=cfg.beam_size, - bos=self.tokenizer.bos_id, - pad=self.tokenizer.pad_id, - eos=self.tokenizer.eos_id, - len_pen=cfg.len_pen, - max_delta_length=cfg.max_generation_delta, - ) - @torch.no_grad() def transcribe( self, - paths2audio_files: Union[List[str], str], + paths2audio_files: List[str], batch_size: int = 4, logprobs: bool = False, return_hypotheses: bool = False, @@ -227,48 +206,6 @@ def transcribe( if paths2audio_files is None or len(paths2audio_files) == 0: return {} - if return_hypotheses: - logging.warning("return_hypotheses=True is currently not supported, returning text instead.") - - if isinstance(paths2audio_files, list): - logging.info(f"Found paths2audio_files to be a list of {len(paths2audio_files)} items.") - logging.info(f"Assuming each item in paths2audio_files is a path to audio file.") - logging.info(f"Transcribing with default Canary setting of English without PnC.") - elif isinstance(paths2audio_files, str): - logging.info(f"Found paths2audio_files to be a string. Assuming it is a path to manifest file.") - assert os.path.exists(paths2audio_files), f"File {paths2audio_files} doesn't exist" - assert paths2audio_files.endswith('.json') or paths2audio_files.endswith( - '.jsonl' - ), f"File {paths2audio_files} must be a json or jsonl file" - - # load json lines - manifest_path = paths2audio_files # need to save this as we are overwriting paths2audio_files in nextline - paths2audio_files = manifest_utils.read_manifest(paths2audio_files) - - def _may_be_make_dict_and_fix_paths(json_items): - out_json_items = [] - for item in json_items: - if isinstance(item, str): - # assume it is a path to audio file - entry = { - 'audio_filepath': item, - 'duration': 100000, - 'source_lang': 'en', - 'taskname': 'asr', - 'target_lang': 'en', - 'pnc': 'no', - 'answer': 'nothing', - } - elif isinstance(item, dict): - entry = item - entry['audio_filepath'] = get_full_path(entry['audio_filepath'], manifest_file=manifest_path) - else: - raise ValueError(f"Expected str or dict, got {type(item)}") - out_json_items.append(entry) - return out_json_items - - paths2audio_files = _may_be_make_dict_and_fix_paths(paths2audio_files) - if return_hypotheses and logprobs: raise ValueError( "Either `return_hypotheses` or `logprobs` can be True at any given time." 
@@ -301,8 +238,8 @@ def _may_be_make_dict_and_fix_paths(json_items): with tempfile.TemporaryDirectory() as tmpdir: with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: for audio_file in paths2audio_files: - # _may_be_make_dict_and_fix_paths has already fixed the path and added other fields if needed - fp.write(json.dumps(audio_file) + '\n') + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') config = { 'paths2audio_files': paths2audio_files, @@ -323,21 +260,14 @@ def _may_be_make_dict_and_fix_paths(json_items): beam_hypotheses = ( self.beam_search( - encoder_hidden_states=enc_states, - encoder_input_mask=enc_mask, - return_beam_scores=False, - decoder_input_ids=test_batch[2][:, : self.context_len_for_AR_decoding].to(device) - if self.context_len_for_AR_decoding > 0 - else None, + encoder_hidden_states=enc_states, encoder_input_mask=enc_mask, return_beam_scores=False ) .detach() .cpu() .numpy() ) - beam_hypotheses = [ - self._strip_special_tokens(self.tokenizer.ids_to_text(hyp)) for hyp in beam_hypotheses - ] + beam_hypotheses = [self.tokenizer.ids_to_text(hyp) for hyp in beam_hypotheses] # TODO: add support for return_hypotheses=True @AlexGrinch # if return_hypotheses: @@ -367,7 +297,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config, global_rank=self.global_rank, world_size=self.world_size, - dataset=CanaryDataset(tokenizer=self.tokenizer), + dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer,), ) dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( @@ -585,17 +515,6 @@ def training_step(self, batch, batch_nb): return {'loss': audio_loss, 'log': tensorboard_logs} - def _strip_special_tokens(self, text): - """ - assuming all special tokens are of format - Note that if any label/pred is of format , it will be stripped - """ - assert isinstance(text, str), f"Expected str, got {type(text)}" - text = re.sub(r'<[^>]+>', '', text) - # strip spaces at the beginning and end; - # this is training data artifact, will be fixed in future (@kpuvvada) - return text.strip() - def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): signal, signal_len, transcript, transcript_len = batch input_ids, labels = transcript[:, :-1], transcript[:, 1:] @@ -616,12 +535,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): ) beam_hypotheses = self.beam_search( - encoder_hidden_states=enc_states, - encoder_input_mask=enc_mask, - return_beam_scores=False, - decoder_input_ids=input_ids[:, : self.context_len_for_AR_decoding] - if self.context_len_for_AR_decoding > 0 - else None, + encoder_hidden_states=enc_states, encoder_input_mask=enc_mask, return_beam_scores=False ) transf_loss = self.transf_loss(log_probs=transf_log_probs, labels=labels) @@ -630,16 +544,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): self.val_loss(loss=transf_loss, num_measurements=transf_log_probs.shape[0] * transf_log_probs.shape[1]) - output_dict = { - f'{eval_mode}_loss': transf_loss, - 'translations': [self._strip_special_tokens(t) for t in translations], - 'ground_truths': [self._strip_special_tokens(g) for g in ground_truths], - } + output_dict = {f'{eval_mode}_loss': transf_loss, 'translations': translations, 'ground_truths': ground_truths} - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(output_dict) - else: - 
self.validation_step_outputs.append(output_dict) + self.validation_step_outputs.append(output_dict) return output_dict @@ -692,22 +599,9 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, eval_mode sb_score = 0.0 wer_score = 0.0 - # To log via on_validation_epoch_end in modelPT.py - # remove (* self.world_size) if logging via on_validation_epoch_end - # tensorboard_logs = {} - # tensorboard_logs.update({f"{eval_mode}_loss": eval_loss}) - # tensorboard_logs.update({f"{eval_mode}_sacreBLEU": sb_score}) - # tensorboard_logs.update({f"{eval_mode}_WER": wer_score}) - - # logging here only. - dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx) - self.log(f"{dataloader_prefix}{eval_mode}_loss", eval_loss, sync_dist=True) - self.log(f"{dataloader_prefix}{eval_mode}_sacreBLEU", sb_score, sync_dist=True) - self.log(f"{dataloader_prefix}{eval_mode}_WER", wer_score, sync_dist=True) - - # in multi-validation case, anything after first one will become NaN - # as we are resetting the metric here. - # TODO: fix this, (not sure which hook will be ideal for this) + self.log(f"{eval_mode}_loss", eval_loss, sync_dist=True) + self.log(f"{eval_mode}_sacreBLEU", sb_score, sync_dist=True) + self.log(f"{eval_mode}_WER", wer_score, sync_dist=True) self.val_loss.reset() def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): @@ -742,24 +636,5 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'pin_memory': True, } - # TODO: remove this lhotse hardcoding later (@kpuvvada) - # currently only works for non-tarred - lhotse_config = { - 'is_tarred': False, - 'batch_size': 1, - 'force_strip_pnc': False, - 'use_lhotse': True, - 'lhotse': { - 'use_bucketing': False, - 'max_cuts': batch_size, - 'drop_last': False, - 'text_field': 'answer', - 'lang_field': 'target_lang', - }, - } - - # update dl_config - dl_config.update(lhotse_config) - temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) return temporary_datalayer From 3cadf70a7d3e67bcce4583c52e12547620d65545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 26 Jan 2024 15:07:16 -0500 Subject: [PATCH 3/7] Rename `CanaryDataset` to `PromptedAudioToTextLhotseDataset`; add `prompt_format_fn` argument; clean-up the `_canary_prompt_format` function a bit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_canary.py | 104 ------------- .../asr/data/audio_to_text_lhotse_prompted.py | 138 ++++++++++++++++++ .../asr/models/aed_multitask_models.py | 9 +- 3 files changed, 145 insertions(+), 106 deletions(-) delete mode 100644 nemo/collections/asr/data/audio_to_text_canary.py create mode 100644 nemo/collections/asr/data/audio_to_text_lhotse_prompted.py diff --git a/nemo/collections/asr/data/audio_to_text_canary.py b/nemo/collections/asr/data/audio_to_text_canary.py deleted file mode 100644 index e2f89bf78c2a..000000000000 --- a/nemo/collections/asr/data/audio_to_text_canary.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch.utils.data -from lhotse.cut import MixedCut, MonoCut -from lhotse.dataset import AudioSamples -from lhotse.dataset.collation import collate_vectors - -from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper - - -class CanaryDataset(torch.utils.data.Dataset): - """ - This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`. - It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors. - The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce - a special prompt format for Canary model, which has an encoder-decoder architecture. - """ - - def __init__(self, tokenizer): - super().__init__() - self.tokenizer = TokenizerWrapper(tokenizer) - self.load_audio = AudioSamples(fault_tolerant=True) - self.padding_value = self.tokenizer._tokenizer.pad_id - - def __getitem__(self, cuts) -> tuple[torch.Tensor, ...]: - audio, audio_lens, cuts = self.load_audio(cuts) - - tokens = [self.tokenizer(c.supervisions[0].text, c.supervisions[0].language) for c in cuts] - tokens = self._canary_format(tokens, cuts) - tokens = [torch.as_tensor(t) for t in tokens] - token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) - tokens = collate_vectors(tokens, padding_value=self.padding_value) - - return audio, audio_lens, tokens, token_lens - - def _canary_format(self, tokens, cuts): - """ - prepend and append control tokens to the token sequence as per canary format - - Format: - sot, src_lang_id/no_speech, transcribe/translate, tgt_lang_id, text, eot - """ - canary_tokens = [] - for t, c in zip(tokens, cuts): - if isinstance(c, MixedCut): - c = c._first_non_padding_cut - assert isinstance(c, MonoCut), "Expected MonoCut." 
- - c_t = [] # canary_tokens for this cut - - # bos - c_t.append(self.tokenizer._tokenizer.bos_id) - - # if len(t) is 0 append no-speech token - if len(t) == 0: - c_t.append(self.tokenizer._tokenizer.nospeech_id) - else: - # src_lang_id/no_speech - src_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['source_lang']) - c_t.append(src_lang_id) - - # task - task = c.custom['taskname'] - if task == 'asr': - c_t.append(self.tokenizer._tokenizer.transcribe_id) - elif task == 's2t_translation': - c_t.append(self.tokenizer._tokenizer.translate_id) - else: - raise ValueError(f"Unknown task: {task}") - - # tgt_lang_id - tgt_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['target_lang']) - c_t.append(tgt_lang_id) - - # PnC - pnc = f"{c.custom['pnc']}".lower().strip() # to account for bool or str - if pnc in set(['yes', 'true']): - c_t.append(self.tokenizer._tokenizer.pnc_id) - elif pnc in set(['no', 'false']): - c_t.append(self.tokenizer._tokenizer.nopnc_id) - else: - raise ValueError(f"Unknown PnC: {pnc}") - - # text - c_t.extend(t) - - # eos - c_t.append(self.tokenizer._tokenizer.eos_id) - - canary_tokens.append(c_t) - - return canary_tokens diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py new file mode 100644 index 000000000000..56aafda33f67 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Sequence + +import torch.utils.data +from lhotse import CutSet +from lhotse.cut import MixedCut, MonoCut +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors + +from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.tokenizers import CanaryTokenizer + + +class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): + """ + This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`. + It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors. + The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce + a special prompt format for multitask encoder-decoder models. 
+ """ + + def __init__( + self, tokenizer, prompt_format_fn: Callable[[Sequence[Sequence[int]], CutSet], Sequence[Sequence[int]]] + ): + super().__init__() + self.tokenizer = TokenizerWrapper(tokenizer) + self.load_audio = AudioSamples(fault_tolerant=True) + self.padding_value = self.tokenizer._tokenizer.pad_id + self.prompt_format_fn = prompt_format_fn + + def __getitem__(self, cuts) -> tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + + tokens = [self.tokenizer(c.supervisions[0].text, c.supervisions[0].language) for c in cuts] + tokens = self.prompt_format_fn(tokens, cuts) + tokens = [torch.as_tensor(t) for t in tokens] + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=self.padding_value) + + return audio, audio_lens, tokens, token_lens + + +def _canary_prompt_format( + tokens_batch: Sequence[Sequence[int]], cuts: CutSet, tokenizer: TokenizerWrapper +) -> Sequence[Sequence[int]]: + """ + prepend and append control tokens to the token sequence as per canary format + + Format: + sot, src_lang_id/no_speech, transcribe/translate, tgt_lang_id, text, eot + """ + + assert isinstance( + tokenizer._tokenizer, CanaryTokenizer + ), "To use 'canary' prompt format, you must use the CanaryTokenizer." + tokenizer = tokenizer._tokenizer + + canary_tokens = [] + for tokens, cut in zip(tokens_batch, cuts): + if isinstance(cut, MixedCut): + cut = cut._first_non_padding_cut + assert isinstance(cut, MonoCut), "Expected MonoCut." + + missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in cut.custom] + if missing_keys: + raise RuntimeError( + f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" + f"Please ensure that every utterance in the input manifests contains these keys." + ) + + # bos + prompted_tokens = [tokenizer.bos_id] + + if len(tokens) == 0: + # no speech token + prompted_tokens.append(tokenizer.nospeech_id) + else: + # src_lang_id/no_speech + src_lang_id = tokenizer.to_language_id(cut.custom['source_lang']) + prompted_tokens.append(src_lang_id) + + # task + task = cut.custom['taskname'] + if task == 'asr': + prompted_tokens.append(tokenizer.transcribe_id) + elif task == 's2t_translation': + prompted_tokens.append(tokenizer.translate_id) + else: + raise ValueError(f"Unknown task: {task} for cut ID: {cut.id}") + + # tgt_lang_id + tgt_lang_id = tokenizer.to_language_id(cut.custom['target_lang']) + prompted_tokens.append(tgt_lang_id) + + # PnC + pnc = f"{cut.custom['pnc']}".lower().strip() # to account for bool or str + if pnc in {'yes', 'true'}: + prompted_tokens.append(tokenizer.pnc_id) + elif pnc in {'no', 'false'}: + prompted_tokens.append(tokenizer.nopnc_id) + else: + raise ValueError(f"Unknown value for key 'pnc': {pnc} for cut ID: {cut.id}") + + # text + prompted_tokens.extend(tokens) + + # eos + prompted_tokens.append(tokenizer.eos_id) + + canary_tokens.append(prompted_tokens) + + return canary_tokens + + +# Mapping from a string name to a known prompt formatter function. 
+PROMPT_FORMAT_FNS = { + "canary": _canary_prompt_format, +} + + +def get_prompt_format_fn(name: str) -> Callable[[Sequence[Sequence[int]], CutSet], Sequence[Sequence[int]]]: + if name not in PROMPT_FORMAT_FNS: + raise ValueError( + f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" + ) + return PROMPT_FORMAT_FNS[name] diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 7d4e036e22e9..337cf6b0f2de 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -27,8 +27,11 @@ from pytorch_lightning import Trainer from tqdm.auto import tqdm -from nemo.collections.asr.data.audio_to_text_canary import CanaryDataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs +from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( + PromptedAudioToTextLhotseDataset, + get_prompt_format_fn, +) from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.parts.mixins import ASRBPEMixin from nemo.collections.asr.parts.utils import manifest_utils @@ -368,7 +371,9 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config, global_rank=self.global_rank, world_size=self.world_size, - dataset=CanaryDataset(tokenizer=self.tokenizer), + dataset=PromptedAudioToTextLhotseDataset( + tokenizer=self.tokenizer, prompt_format_fn=get_prompt_format_fn("canary"), + ), ) def setup_training_data(self, train_data_config: Optional[DictConfig]): From 3cc5bdf4b305857d93fbb28d4ae56b8403c09017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 26 Jan 2024 15:17:26 -0500 Subject: [PATCH 4/7] Move tokenization into `prompt_format_fn`, fix usage, add docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 56aafda33f67..97c467b7763a 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -29,11 +29,18 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors. The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce a special prompt format for multitask encoder-decoder models. + + To perform the prompt formatting, we accept a ``prompt_format_fn``. + It's expected to accept: + * a ``CutSet`` which it will internally iterate over for utterances, and + * a ``TokenizerWrapper`` object that will be internally used to tokenize the utterances + + Tokenized utterances will be extended with special prompt tokens according to ``prompt_format_fn`` logic. + We support cuts with multiple supervision segments -- their tokenized texts will be concatenated before we add the prompt tokens. + This is useful, for example, in code-switched scenarios where each segment is spoken in a different language. 
""" - def __init__( - self, tokenizer, prompt_format_fn: Callable[[Sequence[Sequence[int]], CutSet], Sequence[Sequence[int]]] - ): + def __init__(self, tokenizer, prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) @@ -43,8 +50,7 @@ def __init__( def __getitem__(self, cuts) -> tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) - tokens = [self.tokenizer(c.supervisions[0].text, c.supervisions[0].language) for c in cuts] - tokens = self.prompt_format_fn(tokens, cuts) + tokens = self.prompt_format_fn(cuts, self.tokenizer) tokens = [torch.as_tensor(t) for t in tokens] token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) tokens = collate_vectors(tokens, padding_value=self.padding_value) @@ -52,10 +58,9 @@ def __getitem__(self, cuts) -> tuple[torch.Tensor, ...]: return audio, audio_lens, tokens, token_lens -def _canary_prompt_format( - tokens_batch: Sequence[Sequence[int]], cuts: CutSet, tokenizer: TokenizerWrapper -) -> Sequence[Sequence[int]]: +def _canary_prompt_format(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]: """ + prepend and append control tokens to the token sequence as per canary format Format: @@ -68,7 +73,7 @@ def _canary_prompt_format( tokenizer = tokenizer._tokenizer canary_tokens = [] - for tokens, cut in zip(tokens_batch, cuts): + for cut in cuts: if isinstance(cut, MixedCut): cut = cut._first_non_padding_cut assert isinstance(cut, MonoCut), "Expected MonoCut." @@ -80,6 +85,9 @@ def _canary_prompt_format( f"Please ensure that every utterance in the input manifests contains these keys." ) + # Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together. + tokens = sum((tokenizer.text_to_ids(sup.text, sup.language) for sup in cut.supervisions), start=[]) + # bos prompted_tokens = [tokenizer.bos_id] From d6884de74d7f4482f9c9cc07969b7ed809fa6df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 26 Jan 2024 15:19:14 -0500 Subject: [PATCH 5/7] Backward-compatible utterance validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 97c467b7763a..7e84685d5b15 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -78,13 +78,6 @@ def _canary_prompt_format(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence cut = cut._first_non_padding_cut assert isinstance(cut, MonoCut), "Expected MonoCut." - missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in cut.custom] - if missing_keys: - raise RuntimeError( - f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" - f"Please ensure that every utterance in the input manifests contains these keys." - ) - # Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together. 
tokens = sum((tokenizer.text_to_ids(sup.text, sup.language) for sup in cut.supervisions), start=[]) @@ -95,6 +88,14 @@ def _canary_prompt_format(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence # no speech token prompted_tokens.append(tokenizer.nospeech_id) else: + # first, validate the utterance + missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in cut.custom] + if missing_keys: + raise RuntimeError( + f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" + f"Please ensure that every utterance in the input manifests contains these keys." + ) + # src_lang_id/no_speech src_lang_id = tokenizer.to_language_id(cut.custom['source_lang']) prompted_tokens.append(src_lang_id) From ae3dcea862841f3b7f9beb5c12495fc06699e533 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 26 Jan 2024 15:24:52 -0500 Subject: [PATCH 6/7] Improve type annotations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../collections/asr/data/audio_to_text_lhotse_prompted.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 7e84685d5b15..ce3200272e00 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -20,7 +20,7 @@ from lhotse.dataset.collation import collate_vectors from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper -from nemo.collections.common.tokenizers import CanaryTokenizer +from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): @@ -40,14 +40,16 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): This is useful, for example, in code-switched scenarios where each segment is spoken in a different language. 
""" - def __init__(self, tokenizer, prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]): + def __init__( + self, tokenizer: TokenizerSpec, prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]] + ): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) self.padding_value = self.tokenizer._tokenizer.pad_id self.prompt_format_fn = prompt_format_fn - def __getitem__(self, cuts) -> tuple[torch.Tensor, ...]: + def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: audio, audio_lens, cuts = self.load_audio(cuts) tokens = self.prompt_format_fn(cuts, self.tokenizer) From 7743e2f4a2aedf9179c8c23e7fcdd7c59b6a60ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 26 Jan 2024 17:07:28 -0500 Subject: [PATCH 7/7] config and prompt_fn registration changes from review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_multitask/fast-conformer_aed.yaml | 48 ++++++++++--------- .../asr/data/audio_to_text_lhotse_prompted.py | 47 ++++++++++++------ 2 files changed, 58 insertions(+), 37 deletions(-) diff --git a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml index 6283b364bec4..c360c16225de 100644 --- a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml +++ b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml @@ -22,14 +22,16 @@ model: log_prediction: true # enables logging sample predictions in the output during training train_ds: - is_tarred: true + use_lhotse: true tarred_audio_filepaths: ??? manifest_filepath: ??? - sample_rate: 16000 + sample_rate: ${model.sample_rate} shuffle: true - trim_silence: false num_workers: 8 - use_lhotse: true + # To understand the settings below, please refer to Lhotse Dataloading documentation: + # https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#lhotse-dataloading + # You can also check the following configuration dataclass: + # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/lhotse/dataloader.py#L36 batch_size: None batch_duration: 360 quadratic_duration: 20 @@ -39,6 +41,7 @@ model: shuffle_buffer_size: 10000 validation_ds: + use_lhotse: true manifest_filepath: ??? sample_rate: ${model.sample_rate} batch_size: 8 # you may increase batch_size if your memory allows @@ -46,10 +49,11 @@ model: num_workers: 4 pin_memory: true use_start_end_token: true - use_lhotse: true use_bucketing: false + drop_last: false test_ds: + use_lhotse: true manifest_filepath: ??? 
sample_rate: ${model.sample_rate} batch_size: 8 # you may increase batch_size if your memory allows @@ -57,8 +61,8 @@ model: num_workers: 4 pin_memory: true use_start_end_token: true - use_lhotse: true use_bucketing: false + drop_last: false # recommend small vocab size of 128 or 256 when using 4x sub-sampling # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py @@ -93,8 +97,8 @@ model: _target_: nemo.collections.asr.modules.ConformerEncoder feat_in: ${model.preprocessor.features} feat_out: -1 # you may set it if you need different output size other than the default d_model - n_layers: 17 - d_model: 512 + n_layers: 24 + d_model: 1024 # Sub-sampling params subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory @@ -113,7 +117,7 @@ model: n_heads: 8 # may need to be lower for smaller d_models # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention att_context_size: [-1, -1] # -1 means unlimited context - xscaling: true # scales up the input embeddings by sqrt(d_model) + xscaling: false # scales up the input embeddings by sqrt(d_model) untie_biases: true # unties the biases of the TransformerXL layers pos_emb_max_len: 5000 @@ -145,10 +149,10 @@ model: num_token_types: 0 embedding_dropout: 0.1 learn_positional_encodings: false - hidden_size: 512 - inner_size: 2048 - num_layers: 6 - num_attention_heads: 4 + hidden_size: 1024 + inner_size: 4096 + num_layers: 24 + num_attention_heads: 8 ffn_dropout: 0.1 attn_score_dropout: 0.1 attn_layer_dropout: 0.1 @@ -164,34 +168,34 @@ model: use_transformer_init: true beam_search: - beam_size: 4 + beam_size: 1 len_pen: 0.0 max_generation_delta: 50 optim: - name: adam - lr: 0.0001 + name: adamw + lr: 3e-4 # optimizer arguments betas: [0.9, 0.98] # less necessity for weight_decay as we already have large augmentations with SpecAug # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used # weight decay of 0.0 with lr of 2.0 also works fine - #weight_decay: 1e-3 + weight_decay: 1e-3 # scheduler setup sched: name: InverseSquareRootAnnealing #d_model: ${model.encoder.d_model} # scheduler config override - warmup_steps: 1000 + warmup_steps: 2500 warmup_ratio: null min_lr: 1e-6 trainer: devices: -1 # number of GPUs, -1 would use all available GPUs num_nodes: 1 - max_epochs: 100 - max_steps: -1 # computed at runtime if not set + max_epochs: -1 + max_steps: 100000 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: ddp @@ -200,7 +204,7 @@ trainer: precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP. log_every_n_steps: 100 # Interval of logging. 
enable_progress_bar: True - num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + num_sanity_val_steps: 2 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs sync_batchnorm: true enable_checkpointing: False # Provided by exp_manager @@ -220,7 +224,7 @@ exp_manager: resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. # you need to set these two to True to continue the training - resume_if_exists: false + resume_if_exists: true resume_ignore_no_checkpoint: false # You may use this section to create a W&B logger diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index ce3200272e00..f8f34a582c20 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -60,7 +60,38 @@ def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.T return audio, audio_lens, tokens, token_lens -def _canary_prompt_format(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]: +# Mapping from a string name to a known prompt formatter function. +PROMPT_FORMAT_FNS = {} + + +def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]): + """ + Decorator for registering prompt functions under a name. + + Example:: + + >>> @registered_prompt_format_fn + ... def my_prompt(cuts, tokenizer): + ... pass + ... + ... prompt_fn = get_prompt_format_fn("my_prompt") + """ + global PROMPT_FORMAT_FNS + + PROMPT_FORMAT_FNS[prompt_fn.__name__] = prompt_fn + return prompt_fn + + +def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]: + if name not in PROMPT_FORMAT_FNS: + raise ValueError( + f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" + ) + return PROMPT_FORMAT_FNS[name] + + +@registered_prompt_format_fn +def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]: """ prepend and append control tokens to the token sequence as per canary format @@ -133,17 +164,3 @@ def _canary_prompt_format(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence canary_tokens.append(prompted_tokens) return canary_tokens - - -# Mapping from a string name to a known prompt formatter function. -PROMPT_FORMAT_FNS = { - "canary": _canary_prompt_format, -} - - -def get_prompt_format_fn(name: str) -> Callable[[Sequence[Sequence[int]], CutSet], Sequence[Sequence[int]]]: - if name not in PROMPT_FORMAT_FNS: - raise ValueError( - f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" - ) - return PROMPT_FORMAT_FNS[name]
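
For reference, below is a minimal sketch (not part of the patch series) of how the registration mechanism introduced in PATCH 7/7 could be used to plug a custom prompt format into `PromptedAudioToTextLhotseDataset`. The formatter name `my_plain`, the `build_prompted_dataset` helper, and the assumption that the underlying tokenizer exposes `bos_id`/`eos_id`/`pad_id` are illustrative only; `registered_prompt_format_fn`, `get_prompt_format_fn`, `TokenizerWrapper`, and the dataset class itself come from the patches above.

from typing import Sequence

from lhotse import CutSet

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.asr.data.audio_to_text_lhotse_prompted import (
    PromptedAudioToTextLhotseDataset,
    get_prompt_format_fn,
    registered_prompt_format_fn,
)


@registered_prompt_format_fn
def my_plain(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]:
    """Hypothetical formatter: wrap each utterance in plain BOS/EOS, with no task or language tokens."""
    # Assumes the wrapped tokenizer defines bos_id/eos_id (true for CanaryTokenizer; not guaranteed in general).
    bos = tokenizer._tokenizer.bos_id
    eos = tokenizer._tokenizer.eos_id
    prompted = []
    for cut in cuts:
        # Concatenate all supervision segments of the cut, mirroring what the canary formatter does.
        tokens = []
        for sup in cut.supervisions:
            tokens.extend(tokenizer(sup.text, sup.language))
        prompted.append([bos] + tokens + [eos])
    return prompted


def build_prompted_dataset(tokenizer) -> PromptedAudioToTextLhotseDataset:
    # `tokenizer` is any NeMo tokenizer providing pad_id (used by the dataset for collation).
    # The decorator above registered `my_plain` under its __name__, so it can be looked up by string.
    return PromptedAudioToTextLhotseDataset(
        tokenizer=tokenizer, prompt_format_fn=get_prompt_format_fn("my_plain"),
    )

As `_setup_dataloader_from_config` in `aed_multitask_models.py` above shows, such a dataset instance is passed as `dataset=` to the Lhotse dataloader builder, which samples mini-batches of cuts and hands them to `__getitem__` as a `CutSet`.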