diff --git a/examples/asr/librispeech_conformer_rnnt/README.md b/examples/asr/librispeech_conformer_rnnt/README.md index 2c20c11420b..02f154d6e13 100644 --- a/examples/asr/librispeech_conformer_rnnt/README.md +++ b/examples/asr/librispeech_conformer_rnnt/README.md @@ -2,11 +2,28 @@ This directory contains sample implementations of training and evaluation pipelines for a Conformer RNN-T ASR model. +## Setup +### Install PyTorch and TorchAudio nightly or from source +Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training. + +To install the nightly, follow the directions at . + +To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md). + +### Install additional dependencies +```bash +pip install pytorch-lightning sentencepiece +``` + ## Usage ### Training -[`train.py`](./train.py) trains an Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full LibriSpeech dataset and the SentencePiece model to be used to encode targets. +[`train.py`](./train.py) trains an Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following: +- Access to GPU nodes for training. +- Full LibriSpeech dataset. +- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py) +- File (--global_stats_path) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py). Sample SLURM command: ``` diff --git a/examples/asr/librispeech_conformer_rnnt/eval.py b/examples/asr/librispeech_conformer_rnnt/eval.py index 9f36eaddc17..d21e772df8a 100644 --- a/examples/asr/librispeech_conformer_rnnt/eval.py +++ b/examples/asr/librispeech_conformer_rnnt/eval.py @@ -4,7 +4,7 @@ import torch import torchaudio -from lightning import RNNTModule +from lightning import ConformerRNNTModule logger = logging.getLogger() @@ -15,7 +15,7 @@ def compute_word_level_distance(seq1, seq2): def run_eval(args): - model = RNNTModule.load_from_checkpoint( + model = ConformerRNNTModule.load_from_checkpoint( args.checkpoint_path, librispeech_path=str(args.librispeech_path), sp_model_path=str(args.sp_model_path), diff --git a/examples/asr/librispeech_conformer_rnnt/lightning.py b/examples/asr/librispeech_conformer_rnnt/lightning.py index b583177fe63..fe077273ea2 100644 --- a/examples/asr/librispeech_conformer_rnnt/lightning.py +++ b/examples/asr/librispeech_conformer_rnnt/lightning.py @@ -185,7 +185,7 @@ def post_process_hypos( return nbest_batch -class RNNTModule(LightningModule): +class ConformerRNNTModule(LightningModule): def __init__( self, *, diff --git a/examples/asr/librispeech_conformer_rnnt/train.py b/examples/asr/librispeech_conformer_rnnt/train.py index 493f5ecfa60..ef70c7bb39a 100644 --- a/examples/asr/librispeech_conformer_rnnt/train.py +++ b/examples/asr/librispeech_conformer_rnnt/train.py @@ -1,7 +1,7 @@ import pathlib from argparse import ArgumentParser -from lightning import RNNTModule +from lightning import ConformerRNNTModule from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor from pytorch_lightning.plugins import DDPPlugin @@ -42,7 +42,7 @@ def run_train(args): reload_dataloaders_every_n_epochs=1, ) - model = RNNTModule( + model = ConformerRNNTModule( librispeech_path=str(args.librispeech_path), sp_model_path=str(args.sp_model_path), global_stats_path=str(args.global_stats_path), @@ -75,7 +75,7 @@ def cli_main(): help="Path to SentencePiece model.", ) parser.add_argument( - "--num-nodes", + "--nodes", default=4, type=int, help="Number of nodes to use for training. (Default: 4)", diff --git a/examples/asr/librispeech_conformer_rnnt/train_spm.py b/examples/asr/librispeech_conformer_rnnt/train_spm.py new file mode 100644 index 00000000000..75dba161c4e --- /dev/null +++ b/examples/asr/librispeech_conformer_rnnt/train_spm.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500. + +Example: +python train_spm.py --librispeech-path ./datasets +""" + +import io +import pathlib +from argparse import ArgumentParser, RawTextHelpFormatter + +import sentencepiece as spm + + +def get_transcript_text(transcript_path): + with open(transcript_path) as f: + return [line.strip().split(" ", 1)[1].lower() for line in f] + + +def get_transcripts(dataset_path): + transcript_paths = dataset_path.glob("*/*/*.trans.txt") + merged_transcripts = [] + for path in transcript_paths: + merged_transcripts += get_transcript_text(path) + return merged_transcripts + + +def train_spm(input): + model_writer = io.BytesIO() + spm.SentencePieceTrainer.train( + sentence_iterator=iter(input), + model_writer=model_writer, + vocab_size=1023, + model_type="unigram", + input_sentence_size=-1, + character_coverage=1.0, + bos_id=0, + pad_id=1, + eos_id=2, + unk_id=3, + ) + return model_writer.getvalue() + + +def parse_args(): + default_output_path = "./spm_unigram_1023.model" + parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--librispeech-path", + required=True, + type=pathlib.Path, + help="Path to LibriSpeech dataset.", + ) + parser.add_argument( + "--output-file", + default=pathlib.Path(default_output_path), + type=pathlib.Path, + help=f"File to save model to. (Default: '{default_output_path}')", + ) + return parser.parse_args() + + +def run_cli(): + args = parse_args() + + root = args.librispeech_path / "LibriSpeech" + splits = ["train-clean-100", "train-clean-360", "train-other-500"] + merged_transcripts = [] + for split in splits: + path = pathlib.Path(root) / split + merged_transcripts += get_transcripts(path) + + model = train_spm(merged_transcripts) + + with open(args.output_file, "wb") as f: + f.write(model) + + +if __name__ == "__main__": + run_cli()