From ebda0ec934616c2af77fd27db3735d44964de5ed Mon Sep 17 00:00:00 2001 From: hwangjeff Date: Wed, 9 Feb 2022 20:22:40 -0800 Subject: [PATCH] Refactor Emformer RNNT recipes (#2212) Summary: Consolidates LibriSpeech and TED-LIUM Release 3 Emformer RNN-T training recipes in a single directory. Pull Request resolved: https://github.com/pytorch/audio/pull/2212 Reviewed By: mthrok Differential Revision: D34120104 Pulled By: hwangjeff fbshipit-source-id: 9ed0a5d5b209478a841324f360c5b66268e3b228 --- examples/asr/emformer_rnnt/README.md | 65 +++++++++ examples/asr/emformer_rnnt/common.py | 98 ++++++++++++++ examples/asr/emformer_rnnt/eval.py | 103 ++++++++++++++ .../global_stats.py | 36 +++-- .../librispeech}/global_stats.json | 0 .../librispeech}/lightning.py | 127 +++--------------- .../librispeech}/pipeline_demo.py | 0 .../tedlium3}/eval_pipeline.py | 0 .../tedlium3}/global_stats.json | 0 .../tedlium3}/lightning.py | 123 +++++------------ .../train.py | 88 ++++++------ .../asr/librispeech_emformer_rnnt/README.md | 35 ----- .../asr/librispeech_emformer_rnnt/eval.py | 76 ----------- .../asr/librispeech_emformer_rnnt/train.py | 98 -------------- .../asr/librispeech_emformer_rnnt/utils.py | 15 --- examples/asr/tedlium3_emformer_rnnt/README.md | 42 ------ .../compute_global_stats.py | 80 ----------- examples/asr/tedlium3_emformer_rnnt/eval.py | 101 -------------- .../asr/tedlium3_emformer_rnnt/train_spm.py | 81 ----------- examples/asr/tedlium3_emformer_rnnt/utils.py | 1 - 20 files changed, 390 insertions(+), 779 deletions(-) create mode 100644 examples/asr/emformer_rnnt/README.md create mode 100644 examples/asr/emformer_rnnt/common.py create mode 100644 examples/asr/emformer_rnnt/eval.py rename examples/asr/{librispeech_emformer_rnnt => emformer_rnnt}/global_stats.py (59%) rename examples/asr/{librispeech_emformer_rnnt => emformer_rnnt/librispeech}/global_stats.json (100%) rename examples/asr/{librispeech_emformer_rnnt => emformer_rnnt/librispeech}/lightning.py (67%) rename examples/asr/{librispeech_emformer_rnnt => emformer_rnnt/librispeech}/pipeline_demo.py (100%) rename examples/asr/{tedlium3_emformer_rnnt => emformer_rnnt/tedlium3}/eval_pipeline.py (100%) rename examples/asr/{tedlium3_emformer_rnnt => emformer_rnnt/tedlium3}/global_stats.json (100%) rename examples/asr/{tedlium3_emformer_rnnt => emformer_rnnt/tedlium3}/lightning.py (70%) rename examples/asr/{tedlium3_emformer_rnnt => emformer_rnnt}/train.py (57%) delete mode 100644 examples/asr/librispeech_emformer_rnnt/README.md delete mode 100644 examples/asr/librispeech_emformer_rnnt/eval.py delete mode 100644 examples/asr/librispeech_emformer_rnnt/train.py delete mode 100644 examples/asr/librispeech_emformer_rnnt/utils.py delete mode 100644 examples/asr/tedlium3_emformer_rnnt/README.md delete mode 100644 examples/asr/tedlium3_emformer_rnnt/compute_global_stats.py delete mode 100644 examples/asr/tedlium3_emformer_rnnt/eval.py delete mode 100644 examples/asr/tedlium3_emformer_rnnt/train_spm.py delete mode 120000 examples/asr/tedlium3_emformer_rnnt/utils.py diff --git a/examples/asr/emformer_rnnt/README.md b/examples/asr/emformer_rnnt/README.md new file mode 100644 index 0000000000..8cd04912c4 --- /dev/null +++ b/examples/asr/emformer_rnnt/README.md @@ -0,0 +1,65 @@ +# Emformer RNN-T ASR Example + +This directory contains sample implementations of training and evaluation pipelines for an Emformer RNN-T streaming ASR model. + +## Usage + +### Training + +[`train.py`](./train.py) trains an Emformer RNN-T model using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to datasets and the SentencePiece model to be used to encode targets. The script also expects a file (--global_stats_path) that contains training set feature statistics; this file can be generated via [`global_stats.py`](./global_stats.py). + +### Evaluation + +[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on a given dataset. + +## Model Types + +Currently, we have training recipes for the LibriSpeech and TED-LIUM Release 3 datasets. + +### LibriSpeech + +Sample SLURM command for training: +``` +srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --model_type librispeech --exp_dir ./experiments --dataset_path ./datasets/librispeech --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model +``` + +Sample SLURM command for evaluation: +``` +srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset_path ./datasets/librispeech --sp_model_path ./spm_bpe_4096.model --use_cuda +``` + +Using the sample training command above along with a SentencePiece model trained on LibriSpeech with vocab size 4096 and type bpe, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves an WER of 0.0456 when evaluated on test-clean with [`eval.py`](./eval.py). + +The table below contains WER results for various splits. + +| | WER | +|:-------------------:|-------------:| +| test-clean | 0.0456 | +| test-other | 0.1066 | +| dev-clean | 0.0415 | +| dev-other | 0.1110 | + +[`librispeech/pipeline_demo.py`](./librispeech/pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle that wraps a pre-trained Emformer RNN-T produced by the above recipe to perform streaming and full-context ASR on several LibriSpeech samples. + +### TED-LIUM Release 3 + +Whereas the LibriSpeech model is configured with a vocabulary size of 4096, the TED-LIUM Release 3 model is configured with a vocabulary size of 500. Consequently, the TED-LIUM Release 3 model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol); the rest of the model is identical to the LibriSpeech model. + +Sample SLURM command for training: +``` +srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --model_type tedlium3 --exp_dir ./experiments --dataset_path ./datasets/tedlium --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_500.model --gradient_clip_val 5.0 +``` + +Sample SLURM command for evaluation: +``` +srun python eval.py --model_type tedlium3 --checkpoint_path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium_path ./datasets/tedlium --sp_model_path ./spm-bpe-500.model --use_cuda +``` + +The table below contains WER results for dev and test subsets of TED-LIUM release 3. + +| | WER | +|:-----------:|-------------:| +| dev | 0.108 | +| test | 0.098 | + +[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table. diff --git a/examples/asr/emformer_rnnt/common.py b/examples/asr/emformer_rnnt/common.py new file mode 100644 index 0000000000..36977fcb07 --- /dev/null +++ b/examples/asr/emformer_rnnt/common.py @@ -0,0 +1,98 @@ +import json +import math +from collections import namedtuple +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torchaudio +from torchaudio.models import Hypothesis + + +MODEL_TYPE_LIBRISPEECH = "librispeech" +MODEL_TYPE_TEDLIUM3 = "tedlium3" + + +DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max) +GAIN = pow(10, 0.05 * DECIBEL) +spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160) + +Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) + + +def piecewise_linear_log(x): + x[x > math.e] = torch.log(x[x > math.e]) + x[x <= math.e] = x[x <= math.e] / math.e + return x + + +def batch_by_token_count(idx_target_lengths, token_limit): + batches = [] + current_batch = [] + current_token_count = 0 + for idx, target_length in idx_target_lengths: + if current_token_count + target_length > token_limit: + batches.append(current_batch) + current_batch = [idx] + current_token_count = target_length + else: + current_batch.append(idx) + current_token_count += target_length + + if current_batch: + batches.append(current_batch) + + return batches + + +def post_process_hypos( + hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor +) -> List[Tuple[str, float, List[int], List[int]]]: + post_process_remove_list = [ + sp_model.unk_id(), + sp_model.eos_id(), + sp_model.pad_id(), + ] + filtered_hypo_tokens = [ + [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos + ] + hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens] + hypos_ali = [h.alignment[1:] for h in hypos] + hypos_ids = [h.tokens[1:] for h in hypos] + hypos_score = [[math.exp(h.score)] for h in hypos] + + nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids)) + + return nbest_batch + + +class FunctionalModule(torch.nn.Module): + def __init__(self, functional): + super().__init__() + self.functional = functional + + def forward(self, input): + return self.functional(input) + + +class GlobalStatsNormalization(torch.nn.Module): + def __init__(self, global_stats_path): + super().__init__() + + with open(global_stats_path) as f: + blob = json.loads(f.read()) + + self.mean = torch.tensor(blob["mean"]) + self.invstddev = torch.tensor(blob["invstddev"]) + + def forward(self, input): + return (input - self.mean) * self.invstddev + + +class WarmupLR(torch.optim.lr_scheduler._LRScheduler): + def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False): + self.warmup_updates = warmup_updates + super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) + + def get_lr(self): + return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs] diff --git a/examples/asr/emformer_rnnt/eval.py b/examples/asr/emformer_rnnt/eval.py new file mode 100644 index 0000000000..c5cd004aca --- /dev/null +++ b/examples/asr/emformer_rnnt/eval.py @@ -0,0 +1,103 @@ +import logging +import pathlib +from argparse import ArgumentParser + +import torch +import torchaudio +from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3 +from librispeech.lightning import LibriSpeechRNNTModule +from tedlium3.lightning import TEDLIUM3RNNTModule + + +logger = logging.getLogger() + + +def compute_word_level_distance(seq1, seq2): + return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split()) + + +def run_eval(model): + total_edit_distance = 0 + total_length = 0 + dataloader = model.test_dataloader() + with torch.no_grad(): + for idx, (batch, transcripts) in enumerate(dataloader): + actual = transcripts[0] + predicted = model(batch) + total_edit_distance += compute_word_level_distance(actual, predicted) + total_length += len(actual.split()) + if idx % 100 == 0: + logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}") + logger.info(f"Final WER: {total_edit_distance / total_length}") + + +def get_lightning_module(args): + if args.model_type == MODEL_TYPE_LIBRISPEECH: + return LibriSpeechRNNTModule.load_from_checkpoint( + args.checkpoint_path, + librispeech_path=str(args.dataset_path), + sp_model_path=str(args.sp_model_path), + global_stats_path=str(args.global_stats_path), + ) + elif args.model_type == MODEL_TYPE_TEDLIUM3: + return TEDLIUM3RNNTModule.load_from_checkpoint( + args.checkpoint_path, + tedlium_path=str(args.dataset_path), + sp_model_path=str(args.sp_model_path), + global_stats_path=str(args.global_stats_path), + ) + else: + raise ValueError(f"Encountered unsupported model type {args.model_type}.") + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True) + parser.add_argument( + "--checkpoint_path", + type=pathlib.Path, + help="Path to checkpoint to use for evaluation.", + ) + parser.add_argument( + "--global_stats_path", + default=pathlib.Path("global_stats.json"), + type=pathlib.Path, + help="Path to JSON file containing feature means and stddevs.", + ) + parser.add_argument( + "--dataset_path", + type=pathlib.Path, + help="Path to dataset.", + ) + parser.add_argument( + "--sp_model_path", + type=pathlib.Path, + help="Path to SentencePiece model.", + ) + parser.add_argument( + "--use_cuda", + action="store_true", + default=False, + help="Run using CUDA.", + ) + parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") + return parser.parse_args() + + +def init_logger(debug): + fmt = "%(asctime)s %(message)s" if debug else "%(message)s" + level = logging.DEBUG if debug else logging.INFO + logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") + + +def cli_main(): + args = parse_args() + init_logger(args.debug) + model = get_lightning_module(args) + if args.use_cuda: + model = model.to(device="cuda") + run_eval(model) + + +if __name__ == "__main__": + cli_main() diff --git a/examples/asr/librispeech_emformer_rnnt/global_stats.py b/examples/asr/emformer_rnnt/global_stats.py similarity index 59% rename from examples/asr/librispeech_emformer_rnnt/global_stats.py rename to examples/asr/emformer_rnnt/global_stats.py index b5275a1bf5..820581429d 100644 --- a/examples/asr/librispeech_emformer_rnnt/global_stats.py +++ b/examples/asr/emformer_rnnt/global_stats.py @@ -1,7 +1,7 @@ -"""Generate feature statistics for LibriSpeech training set. +"""Generate feature statistics for training set. Example: -python global_stats.py --librispeech_path /home/librispeech +python global_stats.py --model_type librispeech --dataset_path /home/librispeech """ import json @@ -11,19 +11,20 @@ import torch import torchaudio -from utils import GAIN, piecewise_linear_log, spectrogram_transform +from common import GAIN, MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform logger = logging.getLogger() def parse_args(): parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) + parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True) parser.add_argument( - "--librispeech_path", + "--dataset_path", required=True, type=pathlib.Path, - help="Path to LibriSpeech datasets. " - "All of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.", + help="Path to dataset. " + "For LibriSpeech, all of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.", ) parser.add_argument( "--output_path", @@ -56,15 +57,24 @@ def generate_statistics(samples): return E_x, (E_x_2 - E_x ** 2) ** 0.5 +def get_dataset(args): + if args.model_type == MODEL_TYPE_LIBRISPEECH: + return torch.utils.data.ConcatDataset( + [ + torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-clean-360"), + torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-clean-100"), + torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-other-500"), + ] + ) + elif args.model_type == MODEL_TYPE_TEDLIUM3: + return torchaudio.datasets.TEDLIUM(args.dataset_path, release="release3", subset="train") + else: + raise ValueError(f"Encountered unsupported model type {args.model_type}.") + + def cli_main(): args = parse_args() - dataset = torch.utils.data.ConcatDataset( - [ - torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-360"), - torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-100"), - torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-other-500"), - ] - ) + dataset = get_dataset(args) dataloader = torch.utils.data.DataLoader(dataset, num_workers=4) mean, stddev = generate_statistics(iter(dataloader)) diff --git a/examples/asr/librispeech_emformer_rnnt/global_stats.json b/examples/asr/emformer_rnnt/librispeech/global_stats.json similarity index 100% rename from examples/asr/librispeech_emformer_rnnt/global_stats.json rename to examples/asr/emformer_rnnt/librispeech/global_stats.json diff --git a/examples/asr/librispeech_emformer_rnnt/lightning.py b/examples/asr/emformer_rnnt/librispeech/lightning.py similarity index 67% rename from examples/asr/librispeech_emformer_rnnt/lightning.py rename to examples/asr/emformer_rnnt/librispeech/lightning.py index 58ff7f285b..c70af210b5 100644 --- a/examples/asr/librispeech_emformer_rnnt/lightning.py +++ b/examples/asr/emformer_rnnt/librispeech/lightning.py @@ -1,42 +1,26 @@ -import json -import math import os -from collections import namedtuple -from typing import List, Tuple +from typing import List import sentencepiece as spm import torch import torchaudio -import torchaudio.functional as F +from common import ( + GAIN, + Batch, + FunctionalModule, + GlobalStatsNormalization, + WarmupLR, + batch_by_token_count, + piecewise_linear_log, + post_process_hypos, + spectrogram_transform, +) from pytorch_lightning import LightningModule -from torchaudio.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base -from utils import GAIN, piecewise_linear_log, spectrogram_transform - - -Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) - - -def _batch_by_token_count(idx_target_lengths, token_limit): - batches = [] - current_batch = [] - current_token_count = 0 - for idx, target_length in idx_target_lengths: - if current_token_count + target_length > token_limit: - batches.append(current_batch) - current_batch = [idx] - current_token_count = target_length - else: - current_batch.append(idx) - current_token_count += target_length - - if current_batch: - batches.append(current_batch) - - return batches +from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base class CustomDataset(torch.utils.data.Dataset): - r"""Sort samples by target length and batch to max token count.""" + r"""Sort LibriSpeech samples by target length and batch to max token count.""" def __init__(self, base_dataset, max_token_limit): super().__init__() @@ -54,7 +38,7 @@ def __init__(self, base_dataset, max_token_limit): assert max_token_limit >= idx_target_lengths[0][1] - self.batches = _batch_by_token_count(idx_target_lengths, max_token_limit) + self.batches = batch_by_token_count(idx_target_lengths, max_token_limit) def _target_length(self, fileid, fileid_to_target_length): if fileid not in fileid_to_target_length: @@ -77,74 +61,7 @@ def __len__(self): return len(self.batches) -class TimeMasking(torchaudio.transforms._AxisMasking): - def __init__(self, time_mask_param: int, min_mask_p: float, iid_masks: bool = False) -> None: - super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks) - self.min_mask_p = min_mask_p - - def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor: - if self.iid_masks and specgram.dim() == 4: - mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis + 1]) - return F.mask_along_axis_iid(specgram, mask_param, mask_value, self.axis + 1) - else: - mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis]) - return F.mask_along_axis(specgram, mask_param, mask_value, self.axis) - - -class FunctionalModule(torch.nn.Module): - def __init__(self, functional): - super().__init__() - self.functional = functional - - def forward(self, input): - return self.functional(input) - - -class GlobalStatsNormalization(torch.nn.Module): - def __init__(self, global_stats_path): - super().__init__() - - with open(global_stats_path) as f: - blob = json.loads(f.read()) - - self.mean = torch.tensor(blob["mean"]) - self.invstddev = torch.tensor(blob["invstddev"]) - - def forward(self, input): - return (input - self.mean) * self.invstddev - - -class WarmupLR(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False): - self.warmup_updates = warmup_updates - super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) - - def get_lr(self): - return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs] - - -def post_process_hypos( - hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor -) -> List[Tuple[str, float, List[int], List[int]]]: - post_process_remove_list = [ - sp_model.unk_id(), - sp_model.eos_id(), - sp_model.pad_id(), - ] - filtered_hypo_tokens = [ - [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos - ] - hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens] - hypos_ali = [h.alignment[1:] for h in hypos] - hypos_ids = [h.tokens[1:] for h in hypos] - hypos_score = [[math.exp(h.score)] for h in hypos] - - nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids)) - - return nbest_batch - - -class RNNTModule(LightningModule): +class LibriSpeechRNNTModule(LightningModule): def __init__( self, *, @@ -157,7 +74,6 @@ def __init__( self.model = emformer_rnnt_base(num_symbols=4097) self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8) - self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0) self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000) self.train_data_pipeline = torch.nn.Sequential( @@ -166,8 +82,8 @@ def __init__( FunctionalModule(lambda x: x.transpose(1, 2)), torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27), - TimeMasking(100, 0.2), - TimeMasking(100, 0.2), + torchaudio.transforms.TimeMasking(100, p=0.2), + torchaudio.transforms.TimeMasking(100, p=0.2), FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))), FunctionalModule(lambda x: x.transpose(1, 2)), ) @@ -219,7 +135,7 @@ def _valid_collate_fn(self, samples: List): return Batch(features, feature_lengths, targets, target_lengths) def _test_collate_fn(self, samples: List): - return self._valid_collate_fn(samples), samples + return self._valid_collate_fn(samples), [sample[2] for sample in samples] def _step(self, batch, batch_idx, step_type): if batch is None: @@ -243,11 +159,6 @@ def configure_optimizers(self): return ( [self.optimizer], [ - { - "scheduler": self.lr_scheduler, - "monitor": "Losses/val_loss", - "interval": "epoch", - }, {"scheduler": self.warmup_lr_scheduler, "interval": "step"}, ], ) diff --git a/examples/asr/librispeech_emformer_rnnt/pipeline_demo.py b/examples/asr/emformer_rnnt/librispeech/pipeline_demo.py similarity index 100% rename from examples/asr/librispeech_emformer_rnnt/pipeline_demo.py rename to examples/asr/emformer_rnnt/librispeech/pipeline_demo.py diff --git a/examples/asr/tedlium3_emformer_rnnt/eval_pipeline.py b/examples/asr/emformer_rnnt/tedlium3/eval_pipeline.py similarity index 100% rename from examples/asr/tedlium3_emformer_rnnt/eval_pipeline.py rename to examples/asr/emformer_rnnt/tedlium3/eval_pipeline.py diff --git a/examples/asr/tedlium3_emformer_rnnt/global_stats.json b/examples/asr/emformer_rnnt/tedlium3/global_stats.json similarity index 100% rename from examples/asr/tedlium3_emformer_rnnt/global_stats.json rename to examples/asr/emformer_rnnt/tedlium3/global_stats.json diff --git a/examples/asr/tedlium3_emformer_rnnt/lightning.py b/examples/asr/emformer_rnnt/tedlium3/lightning.py similarity index 70% rename from examples/asr/tedlium3_emformer_rnnt/lightning.py rename to examples/asr/emformer_rnnt/tedlium3/lightning.py index b9589c8050..dd00cbb9ca 100644 --- a/examples/asr/tedlium3_emformer_rnnt/lightning.py +++ b/examples/asr/emformer_rnnt/tedlium3/lightning.py @@ -1,43 +1,26 @@ -import json -import math import os -from collections import namedtuple -from typing import List, Tuple +from typing import List import sentencepiece as spm import torch import torchaudio +from common import ( + GAIN, + Batch, + FunctionalModule, + GlobalStatsNormalization, + WarmupLR, + batch_by_token_count, + piecewise_linear_log, + post_process_hypos, + spectrogram_transform, +) from pytorch_lightning import LightningModule -from torchaudio.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base -from torchaudio.transforms import TimeMasking -from utils import GAIN, piecewise_linear_log, spectrogram_transform - -Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) - - -def _batch_by_token_count(idx_target_lengths, token_limit): - batches = [] - current_batch = [] - current_token_count = 0 - for idx, target_length in idx_target_lengths: - if target_length == -1: - continue - if current_token_count + target_length > token_limit: - batches.append(current_batch) - current_batch = [idx] - current_token_count = target_length - else: - current_batch.append(idx) - current_token_count += target_length - - if current_batch: - batches.append(current_batch) - - return batches +from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base class CustomDataset(torch.utils.data.Dataset): - r"""Sort samples by target length and batch to max durations.""" + r"""Sort TEDLIUM3 samples by target length and batch to max durations.""" def __init__(self, base_dataset, max_token_limit): super().__init__() @@ -46,6 +29,7 @@ def __init__(self, base_dataset, max_token_limit): idx_target_lengths = [ (idx, self._target_length(fileid, line)) for idx, (fileid, line) in enumerate(self.base_dataset._filelist) ] + idx_target_lengths = [(idx, length) for idx, length in idx_target_lengths if length != -1] assert len(idx_target_lengths) > 0 @@ -53,13 +37,13 @@ def __init__(self, base_dataset, max_token_limit): assert max_token_limit >= idx_target_lengths[-1][1] - self.batches = _batch_by_token_count(idx_target_lengths, max_token_limit) + self.batches = batch_by_token_count(idx_target_lengths, max_token_limit)[:100] def _target_length(self, fileid, line): transcript_path = os.path.join(self.base_dataset._path, "stm", fileid) with open(transcript_path + ".stm") as f: transcript = f.readlines()[line] - talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6) + _, _, _, start_time, end_time, _, transcript = transcript.split(" ", 6) if transcript.lower() == "ignore_time_segment_in_scoring\n": return -1 else: @@ -72,72 +56,31 @@ def __len__(self): return len(self.batches) -class FunctionalModule(torch.nn.Module): - def __init__(self, functional): +class EvalDataset(torch.utils.data.IterableDataset): + def __init__(self, base_dataset): super().__init__() - self.functional = functional - - def forward(self, input): - return self.functional(input) - - -class GlobalStatsNormalization(torch.nn.Module): - def __init__(self, global_stats_path): - super().__init__() - - with open(global_stats_path) as f: - blob = json.loads(f.read()) - - self.mean = torch.tensor(blob["mean"]) - self.invstddev = torch.tensor(blob["invstddev"]) - - def forward(self, input): - return (input - self.mean) * self.invstddev - - -class WarmupLR(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False): - self.warmup_updates = warmup_updates - super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) - - def get_lr(self): - return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs] - - -def post_process_hypos( - hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor -) -> List[Tuple[str, float, List[int], List[int]]]: - post_process_remove_list = [ - sp_model.unk_id(), - sp_model.eos_id(), - sp_model.pad_id(), - ] - filtered_hypo_tokens = [ - [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos - ] - hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens] - hypos_ali = [h.alignment[1:] for h in hypos] - hypos_ids = [h.tokens[1:] for h in hypos] - hypos_score = [[math.exp(h.score)] for h in hypos] - - nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids)) + self.base_dataset = base_dataset - return nbest_batch + def __iter__(self): + for sample in iter(self.base_dataset): + actual = sample[2].replace("\n", "") + if actual == "ignore_time_segment_in_scoring": + continue + yield sample -class RNNTModule(LightningModule): +class TEDLIUM3RNNTModule(LightningModule): def __init__( self, *, tedlium_path: str, sp_model_path: str, global_stats_path: str, - reduction: str, ): super().__init__() self.model = emformer_rnnt_base(num_symbols=501) - self.loss = torchaudio.transforms.RNNTLoss(reduction=reduction, clamp=1.0) + self.loss = torchaudio.transforms.RNNTLoss(reduction="mean", clamp=1.0) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8) self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000) @@ -147,8 +90,8 @@ def __init__( FunctionalModule(lambda x: x.transpose(1, 2)), torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27), - TimeMasking(100, p=0.2), - TimeMasking(100, p=0.2), + torchaudio.transforms.TimeMasking(100, p=0.2), + torchaudio.transforms.TimeMasking(100, p=0.2), FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))), FunctionalModule(lambda x: x.transpose(1, 2)), ) @@ -216,7 +159,7 @@ def _valid_collate_fn(self, samples: List): return Batch(features, feature_lengths, targets, target_lengths) def _test_collate_fn(self, samples: List): - return self._valid_collate_fn(samples), samples + return self._valid_collate_fn(samples), [sample[2] for sample in samples] def _step(self, batch, batch_idx, step_type): if batch is None: @@ -280,11 +223,11 @@ def val_dataloader(self): return dataloader def test_dataloader(self): - dataset = torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="test") + dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="test")) dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn) return dataloader def dev_dataloader(self): - dataset = torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev") + dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev")) dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn) return dataloader diff --git a/examples/asr/tedlium3_emformer_rnnt/train.py b/examples/asr/emformer_rnnt/train.py similarity index 57% rename from examples/asr/tedlium3_emformer_rnnt/train.py rename to examples/asr/emformer_rnnt/train.py index c190710759..57feee195b 100644 --- a/examples/asr/tedlium3_emformer_rnnt/train.py +++ b/examples/asr/emformer_rnnt/train.py @@ -1,13 +1,15 @@ import logging import pathlib -from argparse import ArgumentParser, RawTextHelpFormatter +from argparse import ArgumentParser -from lightning import RNNTModule +from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3 +from librispeech.lightning import LibriSpeechRNNTModule from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from tedlium3.lightning import TEDLIUM3RNNTModule -def run_train(args): +def get_trainer(args): checkpoint_dir = args.exp_dir / "checkpoints" checkpoint = ModelCheckpoint( checkpoint_dir, @@ -29,65 +31,68 @@ def run_train(args): checkpoint, train_checkpoint, ] - trainer = Trainer( + return Trainer( default_root_dir=args.exp_dir, max_epochs=args.epochs, num_nodes=args.num_nodes, gpus=args.gpus, accelerator="gpu", strategy="ddp", - gradient_clip_val=5.0, + gradient_clip_val=args.gradient_clip_val, callbacks=callbacks, ) - model = RNNTModule( - tedlium_path=str(args.tedlium_path), - sp_model_path=str(args.sp_model_path), - global_stats_path=str(args.global_stats_path), - reduction=args.reduction, - ) - trainer.fit(model) +def get_lightning_module(args): + if args.model_type == MODEL_TYPE_LIBRISPEECH: + return LibriSpeechRNNTModule( + librispeech_path=str(args.dataset_path), + sp_model_path=str(args.sp_model_path), + global_stats_path=str(args.global_stats_path), + ) + elif args.model_type == MODEL_TYPE_TEDLIUM3: + return TEDLIUM3RNNTModule( + tedlium_path=str(args.dataset_path), + sp_model_path=str(args.sp_model_path), + global_stats_path=str(args.global_stats_path), + ) + else: + raise ValueError(f"Encountered unsupported model type {args.model_type}.") -def _parse_args(): - parser = ArgumentParser( - description=__doc__, - formatter_class=RawTextHelpFormatter, - ) - parser.add_argument( - "--exp-dir", - default=pathlib.Path("./exp"), - type=pathlib.Path, - help="Directory to save checkpoints and logs to. (Default: './exp')", - ) + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True) parser.add_argument( - "--global-stats-path", + "--global_stats_path", default=pathlib.Path("global_stats.json"), type=pathlib.Path, help="Path to JSON file containing feature means and stddevs.", + required=True, ) parser.add_argument( - "--tedlium-path", + "--dataset_path", type=pathlib.Path, + help="Path to datasets.", required=True, - help="Path to TED-LIUM release 3 dataset.", ) parser.add_argument( - "--reduction", - default="mean", - type=str, - help="Reduction option for RNN Transducer loss function." "(Default: ``mean``)", + "--sp_model_path", + type=pathlib.Path, + help="Path to SentencePiece model.", + required=True, ) parser.add_argument( - "--sp-model-path", + "--exp_dir", + default=pathlib.Path("./exp"), type=pathlib.Path, - help="Path to SentencePiece model.", + help="Directory to save checkpoints and logs to. (Default: './exp')", ) parser.add_argument( - "--num-nodes", - default=1, + "--num_nodes", + default=4, type=int, - help="Number of nodes to use for training. (Default: 1)", + help="Number of nodes to use for training. (Default: 4)", ) parser.add_argument( "--gpus", @@ -101,20 +106,25 @@ def _parse_args(): type=int, help="Number of epochs to train for. (Default: 120)", ) + parser.add_argument( + "--gradient_clip_val", default=10.0, type=float, help="Value to clip gradient values to. (Default: 10.0)" + ) parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") return parser.parse_args() -def _init_logger(debug): +def init_logger(debug): fmt = "%(asctime)s %(message)s" if debug else "%(message)s" level = logging.DEBUG if debug else logging.INFO logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") def cli_main(): - args = _parse_args() - _init_logger(args.debug) - run_train(args) + args = parse_args() + init_logger(args.debug) + model = get_lightning_module(args) + trainer = get_trainer(args) + trainer.fit(model) if __name__ == "__main__": diff --git a/examples/asr/librispeech_emformer_rnnt/README.md b/examples/asr/librispeech_emformer_rnnt/README.md deleted file mode 100644 index 63746b5fda..0000000000 --- a/examples/asr/librispeech_emformer_rnnt/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# Emformer RNN-T ASR Example - -This directory contains sample implementations of training and evaluation pipelines for an on-device-oriented streaming-capable Emformer RNN-T ASR model. - -## Usage - -### Training - -[`train.py`](./train.py) trains an Emformer RNN-T model on LibriSpeech using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full LibriSpeech dataset and the SentencePiece model to be used to encode targets. The script also expects a file (--global_stats_path) that contains training set feature statistics; this file can be generated via [`global_stats.py`](./global_stats.py). - -Sample SLURM command: -``` -srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model -``` - -### Evaluation - -[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on LibriSpeech test-clean. - -Using the default configuration along with a SentencePiece model trained on LibriSpeech with vocab size 4096 and type bpe, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves an WER of 0.0466 when evaluated on test-clean with [`eval.py`](./eval.py). - -The table below contains WER results for various splits. - -| | WER | -|:-------------------:|-------------:| -| test-clean | 0.0456 | -| test-other | 0.1066 | -| dev-clean | 0.0415 | -| dev-other | 0.1110 | - - -Sample SLURM command: -``` -srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_bpe_4096.model --use_cuda -``` diff --git a/examples/asr/librispeech_emformer_rnnt/eval.py b/examples/asr/librispeech_emformer_rnnt/eval.py deleted file mode 100644 index 2cae5de8b6..0000000000 --- a/examples/asr/librispeech_emformer_rnnt/eval.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging -import pathlib -from argparse import ArgumentParser - -import torch -import torchaudio -from lightning import RNNTModule - - -logger = logging.getLogger() - - -def compute_word_level_distance(seq1, seq2): - return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split()) - - -def run_eval(args): - model = RNNTModule.load_from_checkpoint( - args.checkpoint_path, - librispeech_path=str(args.librispeech_path), - sp_model_path=str(args.sp_model_path), - global_stats_path=str(args.global_stats_path), - ).eval() - - if args.use_cuda: - model = model.to(device="cuda") - - total_edit_distance = 0 - total_length = 0 - dataloader = model.test_dataloader() - with torch.no_grad(): - for idx, (batch, sample) in enumerate(dataloader): - actual = sample[0][2] - predicted = model(batch) - total_edit_distance += compute_word_level_distance(actual, predicted) - total_length += len(actual.split()) - if idx % 100 == 0: - logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}") - logger.info(f"Final WER: {total_edit_distance / total_length}") - - -def cli_main(): - parser = ArgumentParser() - parser.add_argument( - "--checkpoint_path", - type=pathlib.Path, - help="Path to checkpoint to use for evaluation.", - ) - parser.add_argument( - "--global_stats_path", - default=pathlib.Path("global_stats.json"), - type=pathlib.Path, - help="Path to JSON file containing feature means and stddevs.", - ) - parser.add_argument( - "--librispeech_path", - type=pathlib.Path, - help="Path to LibriSpeech datasets.", - ) - parser.add_argument( - "--sp_model_path", - type=pathlib.Path, - help="Path to SentencePiece model.", - ) - parser.add_argument( - "--use_cuda", - action="store_true", - default=False, - help="Run using CUDA.", - ) - args = parser.parse_args() - run_eval(args) - - -if __name__ == "__main__": - cli_main() diff --git a/examples/asr/librispeech_emformer_rnnt/train.py b/examples/asr/librispeech_emformer_rnnt/train.py deleted file mode 100644 index e18ba21f6e..0000000000 --- a/examples/asr/librispeech_emformer_rnnt/train.py +++ /dev/null @@ -1,98 +0,0 @@ -import pathlib -from argparse import ArgumentParser - -from lightning import RNNTModule -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint - - -def run_train(args): - checkpoint_dir = args.exp_dir / "checkpoints" - checkpoint = ModelCheckpoint( - checkpoint_dir, - monitor="Losses/val_loss", - mode="min", - save_top_k=5, - save_weights_only=True, - verbose=True, - ) - train_checkpoint = ModelCheckpoint( - checkpoint_dir, - monitor="Losses/train_loss", - mode="min", - save_top_k=5, - save_weights_only=True, - verbose=True, - ) - callbacks = [ - checkpoint, - train_checkpoint, - ] - trainer = Trainer( - default_root_dir=args.exp_dir, - max_epochs=args.epochs, - num_nodes=args.num_nodes, - gpus=args.gpus, - accelerator="gpu", - strategy="ddp", - gradient_clip_val=10.0, - callbacks=callbacks, - ) - - model = RNNTModule( - librispeech_path=str(args.librispeech_path), - sp_model_path=str(args.sp_model_path), - global_stats_path=str(args.global_stats_path), - ) - trainer.fit(model) - - -def cli_main(): - parser = ArgumentParser() - parser.add_argument( - "--exp_dir", - default=pathlib.Path("./exp"), - type=pathlib.Path, - help="Directory to save checkpoints and logs to. (Default: './exp')", - ) - parser.add_argument( - "--global_stats_path", - default=pathlib.Path("global_stats.json"), - type=pathlib.Path, - help="Path to JSON file containing feature means and stddevs.", - ) - parser.add_argument( - "--librispeech_path", - type=pathlib.Path, - help="Path to LibriSpeech datasets.", - ) - parser.add_argument( - "--sp_model_path", - type=pathlib.Path, - help="Path to SentencePiece model.", - ) - parser.add_argument( - "--num_nodes", - default=4, - type=int, - help="Number of nodes to use for training. (Default: 4)", - ) - parser.add_argument( - "--gpus", - default=8, - type=int, - help="Number of GPUs per node to use for training. (Default: 8)", - ) - parser.add_argument( - "--epochs", - default=120, - type=int, - help="Number of epochs to train for. (Default: 120)", - ) - args = parser.parse_args() - - run_train(args) - - -if __name__ == "__main__": - cli_main() diff --git a/examples/asr/librispeech_emformer_rnnt/utils.py b/examples/asr/librispeech_emformer_rnnt/utils.py deleted file mode 100644 index 2a763d4d1c..0000000000 --- a/examples/asr/librispeech_emformer_rnnt/utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import math - -import torch -import torchaudio - - -DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max) -GAIN = pow(10, 0.05 * DECIBEL) -spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160) - - -def piecewise_linear_log(x): - x[x > math.e] = torch.log(x[x > math.e]) - x[x <= math.e] = x[x <= math.e] / math.e - return x diff --git a/examples/asr/tedlium3_emformer_rnnt/README.md b/examples/asr/tedlium3_emformer_rnnt/README.md deleted file mode 100644 index 78b4b719ae..0000000000 --- a/examples/asr/tedlium3_emformer_rnnt/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# Emformer RNN-T ASR Example for TED-LIUM release 3 dataset - -This directory contains sample implementations of training and evaluation pipelines for an on-device-oriented streaming-capable Emformer RNN-T ASR model. - -## Usage - -### Training - -[`train.py`](./train.py) trains an Emformer RNN-T model on TED-LIUM release 3 using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full TED-LIUM release 3 dataset and the SentencePiece model to be used to encode targets. - -Sample SLURM command: -``` -srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --exp-dir ./experiments --tedlium-path ./datasets/ --global-stats-path ./global_stats.json --sp-model-path ./spm_bpe_500.model -``` - -### Evaluation - -[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on TED-LIUM release 3 test set. - -The table below contains WER results for dev and test subsets of TED-LIUM release 3. - -| | WER | -|:-----------:|-------------:| -| dev | 0.108 | -| test | 0.098 | - - -Sample SLURM command: -``` -srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium-path ./datasets/ --sp-model-path ./spm-bpe-500.model --use-cuda -``` - -### Evaluation using `torchaudio.pipelines.EMFORMER_RNNT_BASE_TEDLIUM3` bundle - -[`eval_pipeline.py`](./eval_pipeline.py) evaluates the `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. -You should be able to get identical WER results in the above table. - - -Sample SLURM command: -``` -srun python eval_pipeline.py --tedlium-path ./datasets/ --use-cuda -``` diff --git a/examples/asr/tedlium3_emformer_rnnt/compute_global_stats.py b/examples/asr/tedlium3_emformer_rnnt/compute_global_stats.py deleted file mode 100644 index accb41971b..0000000000 --- a/examples/asr/tedlium3_emformer_rnnt/compute_global_stats.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Generate feature statistics for TED-LIUM release 3 training set. -Example: -python compute_global_stats.py --tedlium-path /home/datasets/ -""" - -import json -import logging -import pathlib -from argparse import ArgumentParser, RawTextHelpFormatter - -import torchaudio -from utils import GAIN, piecewise_linear_log, spectrogram_transform - -logger = logging.getLogger(__name__) - - -def _parse_args(): - parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) - parser.add_argument( - "--tedlium-path", - required=True, - type=pathlib.Path, - help="Path to TED-LIUM release 3 dataset.", - ) - parser.add_argument( - "--output-path", - default=pathlib.Path("global_stats.json"), - type=pathlib.Path, - help="File to save feature statistics to. (Default: './global_stats.json')", - ) - parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") - return parser.parse_args() - - -def _compute_stats(dataset): - E_x = 0.0 - E_x_2 = 0.0 - N = 0.0 - for idx, data in enumerate(dataset): - waveform = data[0].squeeze() - mel_spec = spectrogram_transform(waveform) - scaled_mel_spec = piecewise_linear_log(mel_spec * GAIN) - mel_sum = scaled_mel_spec.sum(-1) - mel_sum_sq = scaled_mel_spec.pow(2).sum(-1) - M = scaled_mel_spec.size(1) - - E_x = E_x * (N / (N + M)) + mel_sum / (N + M) - E_x_2 = E_x_2 * (N / (N + M)) + mel_sum_sq / (N + M) - N += M - - if idx % 100 == 0: - logger.info(f"Processed {idx}") - - return E_x, (E_x_2 - E_x ** 2) ** 0.5 - - -def _init_logger(debug): - fmt = "%(asctime)s %(message)s" if debug else "%(message)s" - level = logging.DEBUG if debug else logging.INFO - logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") - - -def cli_main(): - args = _parse_args() - _init_logger(args.debug) - dataset = torchaudio.datasets.TEDLIUM(args.tedlium_path, release="release3", subset="train") - mean, std = _compute_stats(dataset) - invstd = 1 / std - - stats_dict = { - "mean": mean.tolist(), - "invstddev": invstd.tolist(), - } - - with open(args.output_path, "w") as f: - json.dump(stats_dict, f, indent=2) - - -if __name__ == "__main__": - cli_main() diff --git a/examples/asr/tedlium3_emformer_rnnt/eval.py b/examples/asr/tedlium3_emformer_rnnt/eval.py deleted file mode 100644 index 71cd5c5a38..0000000000 --- a/examples/asr/tedlium3_emformer_rnnt/eval.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging -import pathlib -from argparse import ArgumentParser, RawTextHelpFormatter - -import torch -import torchaudio -from lightning import RNNTModule - - -logger = logging.getLogger(__name__) - - -def compute_word_level_distance(seq1, seq2): - return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split()) - - -def _eval_subset(model, subset): - total_edit_distance = 0.0 - total_length = 0.0 - if subset == "dev": - dataloader = model.dev_dataloader() - else: - dataloader = model.test_dataloader() - with torch.no_grad(): - for idx, (batch, sample) in enumerate(dataloader): - actual = sample[0][2].replace("\n", "") - if actual == "ignore_time_segment_in_scoring": - continue - predicted = model(batch) - total_edit_distance += compute_word_level_distance(actual, predicted) - total_length += len(actual.split()) - if idx % 100 == 0: - logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}") - logger.info(f"Final WER for {subset} set: {total_edit_distance / total_length}") - - -def run_eval(args): - model = RNNTModule.load_from_checkpoint( - args.checkpoint_path, - tedlium_path=str(args.tedlium_path), - sp_model_path=str(args.sp_model_path), - global_stats_path=str(args.global_stats_path), - reduction="mean", - ).eval() - - if args.use_cuda: - model = model.to(device="cuda") - _eval_subset(model, "dev") - _eval_subset(model, "test") - - -def _parse_args(): - parser = ArgumentParser( - description=__doc__, - formatter_class=RawTextHelpFormatter, - ) - parser.add_argument( - "--checkpoint-path", - type=pathlib.Path, - help="Path to checkpoint to use for evaluation.", - ) - parser.add_argument( - "--global-stats-path", - default=pathlib.Path("global_stats.json"), - type=pathlib.Path, - help="Path to JSON file containing feature means and stddevs.", - ) - parser.add_argument( - "--tedlium-path", - type=pathlib.Path, - help="Path to TED-LIUM release 3 dataset.", - ) - parser.add_argument( - "--sp-model-path", - type=pathlib.Path, - help="Path to SentencePiece model.", - ) - parser.add_argument( - "--use-cuda", - action="store_true", - default=False, - help="Run using CUDA.", - ) - parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") - return parser.parse_args() - - -def _init_logger(debug): - fmt = "%(asctime)s %(message)s" if debug else "%(message)s" - level = logging.DEBUG if debug else logging.INFO - logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") - - -def cli_main(): - args = _parse_args() - _init_logger(args.debug) - run_eval(args) - - -if __name__ == "__main__": - cli_main() diff --git a/examples/asr/tedlium3_emformer_rnnt/train_spm.py b/examples/asr/tedlium3_emformer_rnnt/train_spm.py deleted file mode 100644 index 5a5115f9c6..0000000000 --- a/examples/asr/tedlium3_emformer_rnnt/train_spm.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Train the SentencePiece model by using the transcripts of TED-LIUM release 3 training set. -Example: -python train_spm.py --tedlium-path /home/datasets/ -""" - -import logging -import os -import pathlib -from argparse import ArgumentParser, RawTextHelpFormatter - -import sentencepiece as spm - -logger = logging.getLogger(__name__) - - -def _parse_args(): - parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) - parser.add_argument( - "--tedlium-path", - required=True, - type=pathlib.Path, - help="Path to TED-LIUM release 3 dataset.", - ) - parser.add_argument( - "--output-dir", - default=pathlib.Path("./"), - type=pathlib.Path, - help="File to save feature statistics to. (Default: './')", - ) - parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") - return parser.parse_args() - - -def _extract_train_text(tedlium_path, output_dir): - stm_path = tedlium_path / "TEDLIUM_release-3/data/stm/" - transcripts = [] - for file in sorted(os.listdir(stm_path)): - if file.endswith(".stm"): - file = os.path.join(stm_path, file) - with open(file) as f: - for line in f.readlines(): - talk_id, _, speaker_id, start_time, end_time, identifier, transcript = line.split(" ", 6) - if transcript == "ignore_time_segment_in_scoring\n": - continue - else: - transcript = transcript.lower().replace("", "") - transcripts.append(transcript) - - with open(output_dir / "text_train.txt", "w") as f: - f.writelines(transcripts) - - -def _init_logger(debug): - fmt = "%(asctime)s %(message)s" if debug else "%(message)s" - level = logging.DEBUG if debug else logging.INFO - logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") - - -def cli_main(): - args = _parse_args() - _init_logger(args.debug) - _extract_train_text(args.tedlium_path, args.output_dir) - - spm.SentencePieceTrainer.train( - input=args.output_dir / "text_train.txt", - vocab_size=500, - model_prefix="spm_bpe_500", - model_type="bpe", - input_sentence_size=100000000, - character_coverage=1.0, - user_defined_symbols=[""], - bos_id=0, - pad_id=1, - eos_id=2, - unk_id=3, - ) - logger.info("Successfully trained the sentencepiece model") - - -if __name__ == "__main__": - cli_main() diff --git a/examples/asr/tedlium3_emformer_rnnt/utils.py b/examples/asr/tedlium3_emformer_rnnt/utils.py deleted file mode 120000 index a141b8a58e..0000000000 --- a/examples/asr/tedlium3_emformer_rnnt/utils.py +++ /dev/null @@ -1 +0,0 @@ -../librispeech_emformer_rnnt/utils.py \ No newline at end of file