Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
hwangjeff committed Apr 13, 2022
1 parent 30459cc commit 6fd974d
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 7 deletions.
19 changes: 18 additions & 1 deletion examples/asr/librispeech_conformer_rnnt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,28 @@

This directory contains sample implementations of training and evaluation pipelines for a Conformer RNN-T ASR model.

## Setup
### Install PyTorch and TorchAudio nightly or from source
Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training.

To install the nightly, follow the directions at <https://pytorch.org/>.

To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md).

### Install additional dependencies
```bash
pip install pytorch-lightning sentencepiece
```

## Usage

### Training

[`train.py`](./train.py) trains an Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full LibriSpeech dataset and the SentencePiece model to be used to encode targets.
[`train.py`](./train.py) trains an Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following:
- Access to GPU nodes for training.
- Full LibriSpeech dataset.
- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py)
- File (--global_stats_path) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py).

Sample SLURM command:
```
Expand Down
4 changes: 2 additions & 2 deletions examples/asr/librispeech_conformer_rnnt/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import torchaudio
from lightning import RNNTModule
from lightning import ConformerRNNTModule


logger = logging.getLogger()
Expand All @@ -15,7 +15,7 @@ def compute_word_level_distance(seq1, seq2):


def run_eval(args):
model = RNNTModule.load_from_checkpoint(
model = ConformerRNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
Expand Down
2 changes: 1 addition & 1 deletion examples/asr/librispeech_conformer_rnnt/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def post_process_hypos(
return nbest_batch


class RNNTModule(LightningModule):
class ConformerRNNTModule(LightningModule):
def __init__(
self,
*,
Expand Down
6 changes: 3 additions & 3 deletions examples/asr/librispeech_conformer_rnnt/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
from argparse import ArgumentParser

from lightning import RNNTModule
from lightning import ConformerRNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin
Expand Down Expand Up @@ -42,7 +42,7 @@ def run_train(args):
reload_dataloaders_every_n_epochs=1,
)

model = RNNTModule(
model = ConformerRNNTModule(
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
Expand Down Expand Up @@ -75,7 +75,7 @@ def cli_main():
help="Path to SentencePiece model.",
)
parser.add_argument(
"--num-nodes",
"--nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
Expand Down
80 changes: 80 additions & 0 deletions examples/asr/librispeech_conformer_rnnt/train_spm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
Example:
python train_spm.py --librispeech-path ./datasets
"""

import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter

import sentencepiece as spm


def get_transcript_text(transcript_path):
with open(transcript_path) as f:
return [line.strip().split(" ", 1)[1].lower() for line in f]


def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*/*.trans.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts


def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=1023,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()


def parse_args():
default_output_path = "./spm_unigram_1023.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech-path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
return parser.parse_args()


def run_cli():
args = parse_args()

root = args.librispeech_path / "LibriSpeech"
splits = ["train-clean-100", "train-clean-360", "train-other-500"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)

model = train_spm(merged_transcripts)

with open(args.output_file, "wb") as f:
f.write(model)


if __name__ == "__main__":
run_cli()

0 comments on commit 6fd974d

Please sign in to comment.