Add SentencePiece model training script for LibriSpeech Emformer RNN-T
hwangjeff committed Feb 11, 2022
1 parent bbdbd58 commit ffaaa6f
Showing 2 changed files with 81 additions and 1 deletion.
4 changes: 3 additions & 1 deletion examples/asr/emformer_rnnt/README.md
@@ -28,7 +28,9 @@ Sample SLURM command for evaluation:
srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset_path ./datasets/librispeech --sp_model_path ./spm_bpe_4096.model --use_cuda
```

The script used for training the SentencePiece model that's referenced by the training command above can be found at [`librispeech/train_spm.py`](./librispeech/train_spm.py); a pretrained SentencePiece model can be downloaded [here](https://download.pytorch.org/torchaudio/pipeline-assets/spm_bpe_4096_librispeech.model).

Using the sample training command above, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves a WER of 0.0456 when evaluated on test-clean with [`eval.py`](./eval.py).

The table below contains WER results for various splits.

78 changes: 78 additions & 0 deletions examples/asr/emformer_rnnt/librispeech/train_spm.py
@@ -0,0 +1,78 @@
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
Example:
python train_spm.py --librispeech-path ./datasets
"""

import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter

import sentencepiece as spm


def get_transcript_text(transcript_path):
    # Each line is "<utterance-id> <TRANSCRIPT>"; keep only the lowercased transcript.
    with open(transcript_path) as f:
        return [line.strip().split(" ", 1)[1].lower() for line in f]
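Each LibriSpeech `*.trans.txt` file holds one utterance per line: an utterance ID, a single space, then the uppercase transcript. A minimal sketch of what the list comprehension above extracts (the ID and text below are illustrative, not taken from the dataset):

```python
# Illustrative transcript line; real IDs follow the <speaker>-<chapter>-<utterance> pattern.
line = "1089-134686-0000 HE COULD WAIT NO LONGER\n"

# Same parsing as get_transcript_text: drop the ID, lowercase the text.
transcript = line.strip().split(" ", 1)[1].lower()
print(transcript)  # he could wait no longer
```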


def get_transcripts(dataset_path):
    # Collect every <speaker>/<chapter>/*.trans.txt file under the given split directory.
    transcript_paths = dataset_path.glob("*/*/*.trans.txt")
    merged_transcripts = []
    for path in transcript_paths:
        merged_transcripts += get_transcript_text(path)
    return merged_transcripts
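The `*/*/*.trans.txt` pattern reflects LibriSpeech's on-disk layout: each split directory contains speaker directories, which contain chapter directories, each holding one `.trans.txt`. A self-contained sketch using a temporary tree (the speaker and chapter IDs are made up):

```python
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Layout: <split>/<speaker>/<chapter>/<speaker>-<chapter>.trans.txt
    split_dir = pathlib.Path(tmp) / "train-clean-100"
    chapter_dir = split_dir / "19" / "198"
    chapter_dir.mkdir(parents=True)
    (chapter_dir / "19-198.trans.txt").write_text("19-198-0000 EXAMPLE TEXT\n")

    # The same glob pattern get_transcripts uses.
    found = [p.name for p in split_dir.glob("*/*/*.trans.txt")]
    print(found)  # ['19-198.trans.txt']
```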


def train_spm(sentences):
    model_writer = io.BytesIO()
    spm.SentencePieceTrainer.train(
        sentence_iterator=iter(sentences),
        model_writer=model_writer,
        vocab_size=4096,
        model_type="bpe",
        input_sentence_size=-1,
        character_coverage=1.0,
        bos_id=0,
        pad_id=1,
        eos_id=2,
        unk_id=3,
    )
    # Return the serialized model so the caller can write it to disk.
    return model_writer.getvalue()


def parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech-path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path("./spm_bpe_4096.model"),
type=pathlib.Path,
help="File to save model to. (Default: './spm_bpe_4096.model')",
)
return parser.parse_args()


def run_cli():
args = parse_args()

root = args.librispeech_path / "LibriSpeech"
splits = ["train-clean-100", "train-clean-360", "train-other-500"]
merged_transcripts = []
for split in splits:
        path = root / split
merged_transcripts += get_transcripts(path)

model = train_spm(merged_transcripts)

with open(args.output_file, "wb") as f:
f.write(model)


if __name__ == "__main__":
run_cli()
