-
Notifications
You must be signed in to change notification settings - Fork 684
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor Emformer RNNT recipes (#2212)
Summary: Consolidates LibriSpeech and TED-LIUM Release 3 Emformer RNN-T training recipes in a single directory. Pull Request resolved: #2212 Reviewed By: mthrok Differential Revision: D34120104 Pulled By: hwangjeff fbshipit-source-id: 9ed0a5d5b209478a841324f360c5b66268e3b228
- Loading branch information
1 parent
87d7694
commit ebda0ec
Showing
20 changed files
with
390 additions
and
779 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Emformer RNN-T ASR Example | ||
|
||
This directory contains sample implementations of training and evaluation pipelines for an Emformer RNN-T streaming ASR model. | ||
|
||
## Usage | ||
|
||
### Training | ||
|
||
[`train.py`](./train.py) trains an Emformer RNN-T model using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to datasets and the SentencePiece model to be used to encode targets. The script also expects a file (--global_stats_path) that contains training set feature statistics; this file can be generated via [`global_stats.py`](./global_stats.py). | ||
|
||
### Evaluation | ||
|
||
[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on a given dataset. | ||
|
||
## Model Types | ||
|
||
Currently, we have training recipes for the LibriSpeech and TED-LIUM Release 3 datasets. | ||
|
||
### LibriSpeech | ||
|
||
Sample SLURM command for training: | ||
``` | ||
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --model_type librispeech --exp_dir ./experiments --dataset_path ./datasets/librispeech --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model | ||
``` | ||
|
||
Sample SLURM command for evaluation: | ||
``` | ||
srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset_path ./datasets/librispeech --sp_model_path ./spm_bpe_4096.model --use_cuda | ||
``` | ||
|
||
Using the sample training command above along with a SentencePiece model trained on LibriSpeech with vocab size 4096 and type bpe, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves an WER of 0.0456 when evaluated on test-clean with [`eval.py`](./eval.py). | ||
|
||
The table below contains WER results for various splits. | ||
|
||
| | WER | | ||
|:-------------------:|-------------:| | ||
| test-clean | 0.0456 | | ||
| test-other | 0.1066 | | ||
| dev-clean | 0.0415 | | ||
| dev-other | 0.1110 | | ||
|
||
[`librispeech/pipeline_demo.py`](./librispeech/pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle that wraps a pre-trained Emformer RNN-T produced by the above recipe to perform streaming and full-context ASR on several LibriSpeech samples. | ||
|
||
### TED-LIUM Release 3 | ||
|
||
Whereas the LibriSpeech model is configured with a vocabulary size of 4096, the TED-LIUM Release 3 model is configured with a vocabulary size of 500. Consequently, the TED-LIUM Release 3 model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol); the rest of the model is identical to the LibriSpeech model. | ||
|
||
Sample SLURM command for training: | ||
``` | ||
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --model_type tedlium3 --exp_dir ./experiments --dataset_path ./datasets/tedlium --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_500.model --gradient_clip_val 5.0 | ||
``` | ||
|
||
Sample SLURM command for evaluation: | ||
``` | ||
srun python eval.py --model_type tedlium3 --checkpoint_path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium_path ./datasets/tedlium --sp_model_path ./spm-bpe-500.model --use_cuda | ||
``` | ||
|
||
The table below contains WER results for dev and test subsets of TED-LIUM release 3. | ||
|
||
| | WER | | ||
|:-----------:|-------------:| | ||
| dev | 0.108 | | ||
| test | 0.098 | | ||
|
||
[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import json | ||
import math | ||
from collections import namedtuple | ||
from typing import List, Tuple | ||
|
||
import sentencepiece as spm | ||
import torch | ||
import torchaudio | ||
from torchaudio.models import Hypothesis | ||
|
||
|
||
MODEL_TYPE_LIBRISPEECH = "librispeech" | ||
MODEL_TYPE_TEDLIUM3 = "tedlium3" | ||
|
||
|
||
DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max) | ||
GAIN = pow(10, 0.05 * DECIBEL) | ||
spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160) | ||
|
||
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) | ||
|
||
|
||
def piecewise_linear_log(x): | ||
x[x > math.e] = torch.log(x[x > math.e]) | ||
x[x <= math.e] = x[x <= math.e] / math.e | ||
return x | ||
|
||
|
||
def batch_by_token_count(idx_target_lengths, token_limit): | ||
batches = [] | ||
current_batch = [] | ||
current_token_count = 0 | ||
for idx, target_length in idx_target_lengths: | ||
if current_token_count + target_length > token_limit: | ||
batches.append(current_batch) | ||
current_batch = [idx] | ||
current_token_count = target_length | ||
else: | ||
current_batch.append(idx) | ||
current_token_count += target_length | ||
|
||
if current_batch: | ||
batches.append(current_batch) | ||
|
||
return batches | ||
|
||
|
||
def post_process_hypos( | ||
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor | ||
) -> List[Tuple[str, float, List[int], List[int]]]: | ||
post_process_remove_list = [ | ||
sp_model.unk_id(), | ||
sp_model.eos_id(), | ||
sp_model.pad_id(), | ||
] | ||
filtered_hypo_tokens = [ | ||
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos | ||
] | ||
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens] | ||
hypos_ali = [h.alignment[1:] for h in hypos] | ||
hypos_ids = [h.tokens[1:] for h in hypos] | ||
hypos_score = [[math.exp(h.score)] for h in hypos] | ||
|
||
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids)) | ||
|
||
return nbest_batch | ||
|
||
|
||
class FunctionalModule(torch.nn.Module): | ||
def __init__(self, functional): | ||
super().__init__() | ||
self.functional = functional | ||
|
||
def forward(self, input): | ||
return self.functional(input) | ||
|
||
|
||
class GlobalStatsNormalization(torch.nn.Module): | ||
def __init__(self, global_stats_path): | ||
super().__init__() | ||
|
||
with open(global_stats_path) as f: | ||
blob = json.loads(f.read()) | ||
|
||
self.mean = torch.tensor(blob["mean"]) | ||
self.invstddev = torch.tensor(blob["invstddev"]) | ||
|
||
def forward(self, input): | ||
return (input - self.mean) * self.invstddev | ||
|
||
|
||
class WarmupLR(torch.optim.lr_scheduler._LRScheduler): | ||
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False): | ||
self.warmup_updates = warmup_updates | ||
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) | ||
|
||
def get_lr(self): | ||
return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import logging | ||
import pathlib | ||
from argparse import ArgumentParser | ||
|
||
import torch | ||
import torchaudio | ||
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3 | ||
from librispeech.lightning import LibriSpeechRNNTModule | ||
from tedlium3.lightning import TEDLIUM3RNNTModule | ||
|
||
|
||
logger = logging.getLogger() | ||
|
||
|
||
def compute_word_level_distance(seq1, seq2): | ||
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split()) | ||
|
||
|
||
def run_eval(model): | ||
total_edit_distance = 0 | ||
total_length = 0 | ||
dataloader = model.test_dataloader() | ||
with torch.no_grad(): | ||
for idx, (batch, transcripts) in enumerate(dataloader): | ||
actual = transcripts[0] | ||
predicted = model(batch) | ||
total_edit_distance += compute_word_level_distance(actual, predicted) | ||
total_length += len(actual.split()) | ||
if idx % 100 == 0: | ||
logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}") | ||
logger.info(f"Final WER: {total_edit_distance / total_length}") | ||
|
||
|
||
def get_lightning_module(args): | ||
if args.model_type == MODEL_TYPE_LIBRISPEECH: | ||
return LibriSpeechRNNTModule.load_from_checkpoint( | ||
args.checkpoint_path, | ||
librispeech_path=str(args.dataset_path), | ||
sp_model_path=str(args.sp_model_path), | ||
global_stats_path=str(args.global_stats_path), | ||
) | ||
elif args.model_type == MODEL_TYPE_TEDLIUM3: | ||
return TEDLIUM3RNNTModule.load_from_checkpoint( | ||
args.checkpoint_path, | ||
tedlium_path=str(args.dataset_path), | ||
sp_model_path=str(args.sp_model_path), | ||
global_stats_path=str(args.global_stats_path), | ||
) | ||
else: | ||
raise ValueError(f"Encountered unsupported model type {args.model_type}.") | ||
|
||
|
||
def parse_args(): | ||
parser = ArgumentParser() | ||
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True) | ||
parser.add_argument( | ||
"--checkpoint_path", | ||
type=pathlib.Path, | ||
help="Path to checkpoint to use for evaluation.", | ||
) | ||
parser.add_argument( | ||
"--global_stats_path", | ||
default=pathlib.Path("global_stats.json"), | ||
type=pathlib.Path, | ||
help="Path to JSON file containing feature means and stddevs.", | ||
) | ||
parser.add_argument( | ||
"--dataset_path", | ||
type=pathlib.Path, | ||
help="Path to dataset.", | ||
) | ||
parser.add_argument( | ||
"--sp_model_path", | ||
type=pathlib.Path, | ||
help="Path to SentencePiece model.", | ||
) | ||
parser.add_argument( | ||
"--use_cuda", | ||
action="store_true", | ||
default=False, | ||
help="Run using CUDA.", | ||
) | ||
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") | ||
return parser.parse_args() | ||
|
||
|
||
def init_logger(debug): | ||
fmt = "%(asctime)s %(message)s" if debug else "%(message)s" | ||
level = logging.DEBUG if debug else logging.INFO | ||
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") | ||
|
||
|
||
def cli_main(): | ||
args = parse_args() | ||
init_logger(args.debug) | ||
model = get_lightning_module(args) | ||
if args.use_cuda: | ||
model = model.to(device="cuda") | ||
run_eval(model) | ||
|
||
|
||
if __name__ == "__main__": | ||
cli_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Oops, something went wrong.