Skip to content

Commit

Permalink
Refactor Emformer RNNT recipes (#2212)
Browse files Browse the repository at this point in the history
Summary:
Consolidates LibriSpeech and TED-LIUM Release 3 Emformer RNN-T training recipes in a single directory.

Pull Request resolved: #2212

Differential Revision: D34120104

Pulled By: hwangjeff

fbshipit-source-id: 38cf6453d19e74f9851ff0544c6c7f604ec5b630
  • Loading branch information
hwangjeff authored and facebook-github-bot committed Feb 9, 2022
1 parent 87d7694 commit be42df8
Show file tree
Hide file tree
Showing 20 changed files with 392 additions and 779 deletions.
63 changes: 63 additions & 0 deletions examples/asr/emformer_rnnt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Emformer RNN-T ASR Example

This directory contains sample implementations of training and evaluation pipelines for an Emformer RNN-T streaming ASR model.

## Usage

### Training

[`train.py`](./train.py) trains an Emformer RNN-T model using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to datasets and the SentencePiece model to be used to encode targets. The script also expects a file (--global_stats_path) that contains training set feature statistics; this file can be generated via [`global_stats.py`](./global_stats.py).

### Evaluation

[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on a given dataset.

## Model Types

Currently, we have training recipes for the LibriSpeech and TED-LIUM Release 3 datasets.

### LibriSpeech

Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --model_type librispeech --exp_dir ./experiments --dataset_path ./datasets/librispeech --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model
```

Sample SLURM command for evaluation:
```
srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset_path ./datasets/librispeech --sp_model_path ./spm_bpe_4096.model --use_cuda
```

Using the sample training command above along with a SentencePiece model trained on LibriSpeech with vocab size 4096 and type bpe, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves an WER of 0.0456 when evaluated on test-clean with [`eval.py`](./eval.py).

The table below contains WER results for various splits.

| | WER |
|:-------------------:|-------------:|
| test-clean | 0.0456 |
| test-other | 0.1066 |
| dev-clean | 0.0415 |
| dev-other | 0.1110 |

[`librispeech/pipeline_demo.py`](./librispeech/pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle that wraps a pre-trained Emformer RNN-T produced by the above recipe to perform streaming and full-context ASR on several LibriSpeech samples.

### TED-LIUM Release 3

Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --model_type tedlium3 --exp_dir ./experiments --dataset_path ./datasets/tedlium --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_500.model --gradient_clip_val 5.0
```

Sample SLURM command for evaluation:
```
srun python eval.py --model_type tedlium3 --checkpoint_path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium_path ./datasets/tedlium --sp_model_path ./spm-bpe-500.model --use_cuda
```

The table below contains WER results for dev and test subsets of TED-LIUM release 3.

| | WER |
|:-----------:|-------------:|
| dev | 0.108 |
| test | 0.098 |

[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table.
98 changes: 98 additions & 0 deletions examples/asr/emformer_rnnt/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import json
import math
from collections import namedtuple
from typing import List, Tuple

import sentencepiece as spm
import torch
import torchaudio
from torchaudio.models import Hypothesis


MODEL_TYPE_LIBRISPEECH = "librispeech"
MODEL_TYPE_TEDLIUM3 = "tedlium3"


DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
GAIN = pow(10, 0.05 * DECIBEL)
spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)

Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])


def piecewise_linear_log(x):
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x


def batch_by_token_count(idx_target_lengths, token_limit):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit:
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length

if current_batch:
batches.append(current_batch)

return batches


def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
hypos_ids = [h.tokens[1:] for h in hypos]
hypos_score = [[math.exp(h.score)] for h in hypos]

nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))

return nbest_batch


class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional

def forward(self, input):
return self.functional(input)


class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()

with open(global_stats_path) as f:
blob = json.loads(f.read())

self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])

def forward(self, input):
return (input - self.mean) * self.invstddev


class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
self.warmup_updates = warmup_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

def get_lr(self):
return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
107 changes: 107 additions & 0 deletions examples/asr/emformer_rnnt/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import logging
import pathlib
from argparse import ArgumentParser

import torch
import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from librispeech.lightning import LibriSpeechRNNTModule
from tedlium3.lightning import TEDLIUM3RNNTModule


logger = logging.getLogger()


def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())


def run_eval(model):
total_edit_distance = 0
total_length = 0
dataloader = model.test_dataloader()
with torch.no_grad():
for idx, (batch, transcripts) in enumerate(dataloader):
actual = transcripts[0]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.info(f"Final WER: {total_edit_distance / total_length}")


def get_lightning_module(args):
if args.model_type == MODEL_TYPE_LIBRISPEECH:
return LibriSpeechRNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
elif args.model_type == MODEL_TYPE_TEDLIUM3:
return TEDLIUM3RNNTModule.load_from_checkpoint(
args.checkpoint_path,
tedlium_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")


def parse_args():
parser = ArgumentParser()
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser.add_argument(
"--checkpoint_path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
required=True,
)
parser.add_argument(
"--global_stats_path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
required=True,
)
parser.add_argument(
"--dataset_path",
type=pathlib.Path,
help="Path to dataset.",
required=True,
)
parser.add_argument(
"--sp_model_path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--use_cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()


def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
if args.use_cuda:
model = model.to(device="cuda")
run_eval(model)


if __name__ == "__main__":
cli_main()
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Generate feature statistics for LibriSpeech training set.
"""Generate feature statistics for training set.
Example:
python global_stats.py --librispeech_path /home/librispeech
python global_stats.py --model_type librispeech --dataset_path /home/librispeech
"""

import json
Expand All @@ -11,19 +11,20 @@

import torch
import torchaudio
from utils import GAIN, piecewise_linear_log, spectrogram_transform
from common import GAIN, MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform

logger = logging.getLogger()


def parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser.add_argument(
"--librispeech_path",
"--dataset_path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech datasets. "
"All of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.",
help="Path to dataset. "
"For LibriSpeech, all of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.",
)
parser.add_argument(
"--output_path",
Expand Down Expand Up @@ -56,15 +57,24 @@ def generate_statistics(samples):
return E_x, (E_x_2 - E_x ** 2) ** 0.5


def get_dataset(args):
if args.model_type == MODEL_TYPE_LIBRISPEECH:
return torch.utils.data.ConcatDataset(
[
torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-other-500"),
]
)
elif args.model_type == MODEL_TYPE_TEDLIUM3:
return torchaudio.datasets.TEDLIUM(args.dataset_path, release="release3", subset="train")
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")


def cli_main():
args = parse_args()
dataset = torch.utils.data.ConcatDataset(
[
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-other-500"),
]
)
dataset = get_dataset(args)
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)
mean, stddev = generate_statistics(iter(dataloader))

Expand Down
Loading

0 comments on commit be42df8

Please sign in to comment.