Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor LibriSpeech Conformer RNN-T recipe #2366

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions examples/asr/librispeech_conformer_rnnt/data_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import torch
import torchaudio
from pytorch_lightning import LightningDataModule, seed_everything

import os
import random


seed_everything(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the right place to seed? I would imagine this happen once (and only once) at the very beginning of the CLI entry point.



def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change sample_limit to batch_size and token_limit to max_tokens per our previous discussion? It's ok if you want to do that in another PR.

batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length

if current_batch:
batches.append(current_batch)

return batches


def get_sample_lengths(librispeech_dataset):
fileid_to_target_length = {}

def _target_length(fileid):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")

file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)

with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)

return fileid_to_target_length[fileid]

return [_target_length(fileid) for fileid in librispeech_dataset._walker]


class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None):
super().__init__()

assert len(dataset) == len(lengths)

self.dataset = dataset

max_length = max(lengths)
min_length = min(lengths)

assert max_token_limit >= max_length

buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)

idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)

sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets], max_token_limit, sample_limit=sample_limit
)

def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]

def __len__(self):
return len(self.batches)


class TransformDataset(torch.utils.data.Dataset):
def __init__(self, dataset, transform_fn):
self.dataset = dataset
self.transform_fn = transform_fn

def __getitem__(self, idx):
return self.transform_fn(self.dataset[idx])

def __len__(self):
return len(self.dataset)


class LibriSpeechDataModule(LightningDataModule):
def __init__(
self,
*,
librispeech_path,
train_transform,
val_transform,
test_transform,
max_token_limit=700,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_tokens should be fine. "max" duplicates with "limit"

sample_limit=2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batch_size

train_num_buckets=50,
train_shuffle=True,
num_workers=10,
):
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
self.max_token_limit = max_token_limit
self.sample_limit = sample_limit
self.train_num_buckets = train_num_buckets
self.train_shuffle = train_shuffle
self.num_workers = num_workers

def train_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
]

if not self.train_dataset_lengths:
self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]

dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset, lengths, self.max_token_limit, self.train_num_buckets, sample_limit=self.sample_limit,
)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle
)
return dataloader

def val_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
]

if not self.val_dataset_lengths:
self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]

dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(dataset, lengths, self.max_token_limit, 1, sample_limit=self.sample_limit)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader

def test_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better make the eval split customizable as we discussed. again I can do that in a later PR as well.

dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
29 changes: 8 additions & 21 deletions examples/asr/librispeech_conformer_rnnt/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import torchaudio
from lightning import ConformerRNNTModule
from lightning import ConformerRNNTModule, get_data_module


logger = logging.getLogger()
Expand All @@ -15,19 +15,15 @@ def compute_word_level_distance(seq1, seq2):


def run_eval(args):
model = ConformerRNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
).eval()
model = ConformerRNNTModule.load_from_checkpoint(args.checkpoint_path, sp_model_path=str(args.sp_model_path)).eval()
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))

if args.use_cuda:
model = model.to(device="cuda")

total_edit_distance = 0
total_length = 0
dataloader = model.test_dataloader()
dataloader = data_module.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2]
Expand All @@ -42,9 +38,7 @@ def run_eval(args):
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
"--checkpoint-path", type=pathlib.Path, help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--global-stats-path",
Expand All @@ -53,20 +47,13 @@ def cli_main():
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
"--librispeech-path", type=pathlib.Path, help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
"--sp-model-path", type=pathlib.Path, help="Path to SentencePiece model.",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
"--use-cuda", action="store_true", default=False, help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
Expand Down
Loading