[canary] Refactor: PromptedAudioToTextLhotseDataset and EncDecMultiTaskModel #8247

Merged
merged 7 commits on Jan 26, 2024
234 changes: 234 additions & 0 deletions examples/asr/conf/speech_multitask/fast-conformer_aed.yaml
@@ -0,0 +1,234 @@
# This config contains the default values for training an autoregressive FastConformer-Transformer multi-task (ASR + speech translation) model with sub-word encoding.

# Architecture and training config:
# The default learning parameters in this config are set for an effective batch size of 2K. To train with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use a higher accumulate_grad_batches.
# Recommended configs exist for other variants of FastConformer-Transformer; parameters not listed there are the same as in this config file.
# An extra (linear projection) layer is added between the FastConformer encoder and the Transformer decoder if they have different hidden sizes.
# It is recommended to initialize the FastConformer encoder from a pre-trained ASR model for better accuracy and faster convergence.
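For a rough sense of how an effective batch size of 2K comes about, a back-of-the-envelope sketch; the numbers below are illustrative assumptions, not values taken from this config, and dynamic batching via batch_duration changes the actual per-step count:

def effective_batch_size(per_gpu_batch, num_gpus, num_nodes, accumulate_grad_batches):
    # Utterances contributing to one optimizer step across all ranks.
    return per_gpu_batch * num_gpus * num_nodes * accumulate_grad_batches

print(effective_batch_size(per_gpu_batch=64, num_gpus=8, num_nodes=4, accumulate_grad_batches=1))  # 2048, i.e. ~2K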

name: "FastConformer-Transformer-MultiTask"

# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy
init_from_nemo_model:
model0:
path: ???
include: ["preprocessor", "encoder"]

model:
_target_: nemo.collections.asr.models.EncDecMultiTaskModel
sample_rate: 16000
label_smoothing: 0.0
log_prediction: true # enables logging sample predictions in the output during training

train_ds:
use_lhotse: true
tarred_audio_filepaths: ???
manifest_filepath: ???
sample_rate: ${model.sample_rate}
shuffle: true
num_workers: 8
# To understand the settings below, please refer to Lhotse Dataloading documentation:
# https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#lhotse-dataloading
# You can also check the following configuration dataclass:
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/lhotse/dataloader.py#L36
batch_size: None
batch_duration: 360
quadratic_duration: 20
use_bucketing: True
num_buckets: 20
bucket_buffer_size: 20000
shuffle_buffer_size: 10000
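To build intuition for batch_duration and quadratic_duration, here is a minimal sketch of how a quadratic penalty limits the number of long cuts per batch. This follows my reading of the Lhotse dataloading docs linked above, so treat the exact formula as an assumption and check the docs:

def penalized_duration(duration, quadratic_duration):
    # Long cuts are counted as longer than they really are, which keeps the cost of
    # padded batches roughly constant as utterance length grows.
    return duration + duration ** 2 / quadratic_duration

# With quadratic_duration=20, a 20 s cut counts as 40 s against the 360 s batch_duration budget.
print(penalized_duration(20.0, 20.0))  # 40.0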

validation_ds:
use_lhotse: true
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 8 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 4
pin_memory: true
use_start_end_token: true
use_bucketing: false
drop_last: false

test_ds:
use_lhotse: true
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 8 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 4
pin_memory: true
use_start_end_token: true
use_bucketing: false
drop_last: false

# a small vocab size of 128 or 256 is recommended when using 4x sub-sampling
# you can find more details on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
Review comment (Collaborator): What about custom_tokenizer: subconfig?

Reply (Author): Let's add in canary-2.
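For type: bpe, the tokenizer dir is expected to contain a SentencePiece model (tokenizer.model). A minimal sketch of inspecting such a model with the sentencepiece package; the path is a hypothetical placeholder:

import sentencepiece as spm

# Point this at the tokenizer.model inside the directory configured as model.tokenizer.dir.
sp = spm.SentencePieceProcessor(model_file="/path/to/tokenizer_dir/tokenizer.model")
print(sp.get_piece_size())                     # vocabulary size of the trained tokenizer
print(sp.encode("hello world", out_type=str))  # sub-word pieces for a sample sentence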


preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
log: true
frame_splicing: 1
dither: 0.00001
pad_to: 0
pad_value: 0.0

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
# you may use a lower time_masks value for smaller models for faster convergence
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05
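A note on time_width: 0.05: my understanding is that a float value below 1.0 is treated as an adaptive mask width, i.e. a fraction of the utterance length rather than a fixed number of frames; treat that as an assumption and verify against the SpectrogramAugmentation docstring. A small sketch of that interpretation (not NeMo's actual code):

def max_time_mask_width(num_frames, time_width):
    if time_width < 1.0:
        # interpreted as a fraction of the utterance length
        return max(1, int(num_frames * time_width))
    return int(time_width)  # otherwise a fixed width in frames

print(max_time_mask_width(1000, 0.05))  # 50 frames for a 1000-frame utterance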

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # set this if you need an output size different from the default d_model
n_layers: 24
d_model: 1024

# Sub-sampling params
subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory
subsampling_factor: 8 # must be a power of 2
subsampling_conv_channels: 256 # -1 sets it to d_model
causal_downsampling: false
reduction: null
reduction_position: null
reduction_factor: 1

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
xscaling: false # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 9
conv_norm_type: batch_norm
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules
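A quick sanity check on the output frame rate implied by these settings, using simple arithmetic from the preprocessor's window_stride and the encoder's subsampling_factor:

window_stride_s = 0.01   # from model.preprocessor.window_stride
subsampling_factor = 8   # from model.encoder.subsampling_factor
encoder_frame_s = window_stride_s * subsampling_factor
print(encoder_frame_s)   # 0.08 s per encoder frame, i.e. 12.5 frames per second of audio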

transf_encoder:
Review comment (Collaborator): Why is there a transformer encoder when we have a FastConformer encoder?

Reply (Author): I think this is the unused "extra" transformer encoder between the actual encoder and the decoder that @AlexGrinch mentioned during the presentation.

Review comment (Collaborator): Huh? Do our current models need it? If not, throw it out. We can create a script to drop the unused parameters out of current models.

num_layers: 0
hidden_size: 512
inner_size: 2048
num_attention_heads: 8
ffn_dropout: 0.1
attn_score_dropout: 0.1
attn_layer_dropout: 0.1
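As the review thread above suggests, with num_layers: 0 this extra transformer encoder should act as a pass-through. A minimal sketch of that assumption (a toy stack, not NeMo's implementation):

import torch

class TinyEncoderStack(torch.nn.Module):
    # A stack that simply iterates over its layers: zero layers means the input is returned unchanged.
    def __init__(self, num_layers, hidden_size):
        super().__init__()
        self.layers = torch.nn.ModuleList(torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

x = torch.randn(2, 4, 512)
assert torch.equal(TinyEncoderStack(0, 512)(x), x)  # zero layers -> identity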

transf_decoder:
library: nemo
model_name: null
pretrained: false
max_sequence_length: 512
num_token_types: 0
embedding_dropout: 0.1
learn_positional_encodings: false
hidden_size: 1024
inner_size: 4096
num_layers: 24
num_attention_heads: 8
ffn_dropout: 0.1
attn_score_dropout: 0.1
attn_layer_dropout: 0.1
hidden_act: relu
pre_ln: true
pre_ln_final_layer_norm: true

head:
num_layers: 1
activation: relu
log_softmax: true
dropout: 0.0
use_transformer_init: true

beam_search:
Review comment (Collaborator): TODO: @titu1994 modify this inside of decoding config in my PR.

beam_size: 1
len_pen: 0.0
max_generation_delta: 50
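For orientation: beam_size: 1 amounts to greedy decoding, and len_pen: 0.0 disables length normalization under the common formulation where beam scores are divided by a length penalty. A sketch of one widely used (GNMT-style) penalty; NeMo's exact formula may differ:

def length_penalty(length, len_pen):
    # Normalized score = raw log-probability sum / length_penalty(length, len_pen).
    return ((5.0 + length) / 6.0) ** len_pen

print(length_penalty(25, 0.0))  # 1.0 -> scores are compared without length normalization
print(length_penalty(25, 1.0))  # 5.0 -> with len_pen > 0, longer hypotheses are normalized by a larger factor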

optim:
name: adamw
lr: 3e-4
# optimizer arguments
betas: [0.9, 0.98]
# weight_decay is less necessary here since SpecAugment already provides strong augmentation
# you may need weight_decay for large models, stable AMP training, small datasets, or when less augmentation is used
# weight decay of 0.0 with lr of 2.0 also works fine
weight_decay: 1e-3

# scheduler setup
sched:
name: InverseSquareRootAnnealing
#d_model: ${model.encoder.d_model}
# scheduler config override
warmup_steps: 2500
warmup_ratio: null
min_lr: 1e-6
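A sketch of the schedule shape these settings imply: linear warmup over warmup_steps, then decay proportional to 1/sqrt(step). This is the textbook inverse-square-root form and may differ in detail from NeMo's InverseSquareRootAnnealing; the min_lr clamp is an assumption:

import math

def inverse_sqrt_lr(step, base_lr=3e-4, warmup_steps=2500, min_lr=1e-6):
    step = max(step, 1)
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps             # linear warmup
    else:
        lr = base_lr * math.sqrt(warmup_steps / step)  # 1/sqrt(step) decay
    return max(lr, min_lr)

for s in (1, 2500, 10000, 100000):
    print(s, inverse_sqrt_lr(s))  # peaks at base_lr when warmup ends, then decays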

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: 100000 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 16 # Should be set to 16 for O1 and O2 to enable AMP.
log_every_n_steps: 100 # Interval of logging.
enable_progress_bar: True
num_sanity_val_steps: 2 # number of validation steps to run as a sanity check before training starts; set to 0 to disable it
check_val_every_n_epoch: 1 # run validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, the first one is used
monitor: "val_sacreBLEU"
mode: "max"
save_top_k: 3
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

resume_from_checkpoint: null # The path to a checkpoint file to continue training from; restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to True to continue the training
resume_if_exists: true
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
80 changes: 3 additions & 77 deletions nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -15,7 +15,6 @@
from typing import Dict, Optional, Tuple

import torch.utils.data
from lhotse.cut import MixedCut, MonoCut
from lhotse.dataset import AudioSamples
from lhotse.dataset.collation import collate_vectors

@@ -44,91 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
}

def __init__(self, tokenizer, token_sequence_format: str = None):
def __init__(self, tokenizer):
super().__init__()
self.tokenizer = TokenizerWrapper(tokenizer)
self.load_audio = AudioSamples(fault_tolerant=True)
assert token_sequence_format is None or token_sequence_format in [
'canary'
], f"Unsupported token_sequence_format: {token_sequence_format}"
self.token_sequence_format = token_sequence_format

def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
audio, audio_lens, cuts = self.load_audio(cuts)

tokens = [self.tokenizer(c.supervisions[0].text, c.supervisions[0].language) for c in cuts]
if self.token_sequence_format == 'canary':
tokens = self._canary_format(tokens, cuts)
tokens = [torch.as_tensor(t) for t in tokens]

tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)

if self.token_sequence_format == 'canary':
padding_value = self.tokenizer._tokenizer.pad_id
else:
padding_value = 0
tokens = collate_vectors(tokens, padding_value=padding_value)

tokens = collate_vectors(tokens, padding_value=0)
return audio, audio_lens, tokens, token_lens

def _canary_format(self, tokens, cuts):
"""
prepend and append control tokens to the token sequence as per canary format

Format:
sot, src_lang_id/no_speech, transcribe/translate, tgt_lang_id, text, eot
"""
canary_tokens = []
for t, c in zip(tokens, cuts):
if isinstance(c, MixedCut):
c = c._first_non_padding_cut
assert isinstance(c, MonoCut), "Expected MonoCut."

c_t = [] # canary_tokens for this cut

# bos
c_t.append(self.tokenizer._tokenizer.bos_id)

# if len(t) is 0 append no-speech token
if len(t) == 0:
c_t.append(self.tokenizer._tokenizer.nospeech_id)
else:
# src_lang_id/no_speech
src_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['source_lang'])
c_t.append(src_lang_id)

# task
task = c.custom['taskname']
if task == 'asr':
c_t.append(self.tokenizer._tokenizer.transcribe_id)
elif task == 's2t_translation':
c_t.append(self.tokenizer._tokenizer.translate_id)
else:
raise ValueError(f"Unknown task: {task}")

# tgt_lang_id
tgt_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['target_lang'])
c_t.append(tgt_lang_id)

# PnC
pnc = f"{c.custom['pnc']}".lower().strip() # to account for bool or str
if pnc in set(['yes', 'true']):
c_t.append(self.tokenizer._tokenizer.pnc_id)
elif pnc in set(['no', 'false']):
c_t.append(self.tokenizer._tokenizer.nopnc_id)
else:
raise ValueError(f"Unknown PnC: {pnc}")

# text
c_t.extend(t)

# eos
c_t.append(self.tokenizer._tokenizer.eos_id)

canary_tokens.append(c_t)

return canary_tokens


class TokenizerWrapper:
"""
(remaining lines collapsed in the diff view)
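Stepping back from the diff: the simplified __getitem__ above now just tokenizes each cut's supervision text and pads the variable-length token sequences into one batch tensor; per the PR title, the Canary-specific prompt formatting presumably moves into the new PromptedAudioToTextLhotseDataset. A small self-contained sketch of that padding step using plain PyTorch and made-up token ids:

import torch

# Stand-ins for tokenized supervisions of different lengths.
tokens = [torch.tensor([5, 17, 42]), torch.tensor([7, 8])]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)

# Pad to the longest sequence with padding_value=0, mirroring collate_vectors(tokens, padding_value=0).
batch = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0)
print(batch)       # tensor([[ 5, 17, 42], [ 7,  8,  0]])
print(token_lens)  # tensor([3, 2])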