From 7e55ba0b1c61e5d40c9cebb3b200f9b5a2cc7b51 Mon Sep 17 00:00:00 2001 From: hwangjeff Date: Fri, 6 May 2022 05:46:42 +0000 Subject: [PATCH 1/4] Refactor LibriSpeech Conformer RNN-T recipe --- .../librispeech_conformer_rnnt/data_module.py | 171 ++++++++++++ .../asr/librispeech_conformer_rnnt/eval.py | 29 +-- .../librispeech_conformer_rnnt/lightning.py | 244 ++---------------- .../asr/librispeech_conformer_rnnt/train.py | 48 +--- .../librispeech_conformer_rnnt/transforms.py | 104 ++++++++ 5 files changed, 309 insertions(+), 287 deletions(-) create mode 100644 examples/asr/librispeech_conformer_rnnt/data_module.py create mode 100644 examples/asr/librispeech_conformer_rnnt/transforms.py diff --git a/examples/asr/librispeech_conformer_rnnt/data_module.py b/examples/asr/librispeech_conformer_rnnt/data_module.py new file mode 100644 index 0000000000..291b036ab8 --- /dev/null +++ b/examples/asr/librispeech_conformer_rnnt/data_module.py @@ -0,0 +1,171 @@ +import torch +import torchaudio +from pytorch_lightning import LightningDataModule, seed_everything + +import os +import random + + +seed_everything(1) + + +def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None): + batches = [] + current_batch = [] + current_token_count = 0 + for idx, target_length in idx_target_lengths: + if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit): + batches.append(current_batch) + current_batch = [idx] + current_token_count = target_length + else: + current_batch.append(idx) + current_token_count += target_length + + if current_batch: + batches.append(current_batch) + + return batches + + +def get_sample_lengths(librispeech_dataset): + fileid_to_target_length = {} + + def _target_length(fileid): + if fileid not in fileid_to_target_length: + speaker_id, chapter_id, _ = fileid.split("-") + + file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt + file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text) + + with open(file_text) as ft: + for line in ft: + fileid_text, transcript = line.strip().split(" ", 1) + fileid_to_target_length[fileid_text] = len(transcript) + + return fileid_to_target_length[fileid] + + return [_target_length(fileid) for fileid in librispeech_dataset._walker] + + +class CustomBucketDataset(torch.utils.data.Dataset): + def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None): + super().__init__() + + assert len(dataset) == len(lengths) + + self.dataset = dataset + + max_length = max(lengths) + min_length = min(lengths) + + assert max_token_limit >= max_length + + buckets = torch.linspace(min_length, max_length, num_buckets) + lengths = torch.tensor(lengths) + bucket_assignments = torch.bucketize(lengths, buckets) + + idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)] + if shuffle: + idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets)) + else: + idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True) + + sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2]) + self.batches = _batch_by_token_count( + [(idx, length) for idx, length, _ in sorted_idx_length_buckets], max_token_limit, sample_limit=sample_limit + ) + + def __getitem__(self, idx): + return [self.dataset[subidx] for subidx in self.batches[idx]] + + def __len__(self): + return len(self.batches) + + +class TransformDataset(torch.utils.data.Dataset): + def 
__init__(self, dataset, transform_fn): + self.dataset = dataset + self.transform_fn = transform_fn + + def __getitem__(self, idx): + return self.transform_fn(self.dataset[idx]) + + def __len__(self): + return len(self.dataset) + + +class LibriSpeechDataModule(LightningDataModule): + def __init__( + self, + *, + librispeech_path, + train_transform, + val_transform, + test_transform, + max_token_limit=700, + sample_limit=2, + train_num_buckets=50, + train_shuffle=True, + num_workers=10, + ): + self.librispeech_path = librispeech_path + self.train_dataset_lengths = None + self.val_dataset_lengths = None + self.train_transform = train_transform + self.val_transform = val_transform + self.test_transform = test_transform + self.max_token_limit = max_token_limit + self.sample_limit = sample_limit + self.train_num_buckets = train_num_buckets + self.train_shuffle = train_shuffle + self.num_workers = num_workers + + def train_dataloader(self): + datasets = [ + torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"), + torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"), + torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"), + ] + + if not self.train_dataset_lengths: + self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets] + + dataset = torch.utils.data.ConcatDataset( + [ + CustomBucketDataset( + dataset, lengths, self.max_token_limit, self.train_num_buckets, sample_limit=self.sample_limit, + ) + for dataset, lengths in zip(datasets, self.train_dataset_lengths) + ] + ) + dataset = TransformDataset(dataset, self.train_transform) + dataloader = torch.utils.data.DataLoader( + dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle + ) + return dataloader + + def val_dataloader(self): + datasets = [ + torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"), + torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"), + ] + + if not self.val_dataset_lengths: + self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets] + + dataset = torch.utils.data.ConcatDataset( + [ + CustomBucketDataset(dataset, lengths, self.max_token_limit, 1, sample_limit=self.sample_limit) + for dataset, lengths in zip(datasets, self.val_dataset_lengths) + ] + ) + dataset = TransformDataset(dataset, self.val_transform) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers) + return dataloader + + def test_dataloader(self): + dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean") + dataset = TransformDataset(dataset, self.test_transform) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=None) + return dataloader diff --git a/examples/asr/librispeech_conformer_rnnt/eval.py b/examples/asr/librispeech_conformer_rnnt/eval.py index d21e772df8..3bea8ac018 100644 --- a/examples/asr/librispeech_conformer_rnnt/eval.py +++ b/examples/asr/librispeech_conformer_rnnt/eval.py @@ -4,7 +4,7 @@ import torch import torchaudio -from lightning import ConformerRNNTModule +from lightning import ConformerRNNTModule, get_data_module logger = logging.getLogger() @@ -15,19 +15,15 @@ def compute_word_level_distance(seq1, seq2): def run_eval(args): - model = ConformerRNNTModule.load_from_checkpoint( - args.checkpoint_path, - librispeech_path=str(args.librispeech_path), - sp_model_path=str(args.sp_model_path), - global_stats_path=str(args.global_stats_path), - ).eval() + model = 
ConformerRNNTModule.load_from_checkpoint(args.checkpoint_path, sp_model_path=str(args.sp_model_path)).eval() + data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path)) if args.use_cuda: model = model.to(device="cuda") total_edit_distance = 0 total_length = 0 - dataloader = model.test_dataloader() + dataloader = data_module.test_dataloader() with torch.no_grad(): for idx, (batch, sample) in enumerate(dataloader): actual = sample[0][2] @@ -42,9 +38,7 @@ def run_eval(args): def cli_main(): parser = ArgumentParser() parser.add_argument( - "--checkpoint-path", - type=pathlib.Path, - help="Path to checkpoint to use for evaluation.", + "--checkpoint-path", type=pathlib.Path, help="Path to checkpoint to use for evaluation.", ) parser.add_argument( "--global-stats-path", @@ -53,20 +47,13 @@ def cli_main(): help="Path to JSON file containing feature means and stddevs.", ) parser.add_argument( - "--librispeech-path", - type=pathlib.Path, - help="Path to LibriSpeech datasets.", + "--librispeech-path", type=pathlib.Path, help="Path to LibriSpeech datasets.", ) parser.add_argument( - "--sp-model-path", - type=pathlib.Path, - help="Path to SentencePiece model.", + "--sp-model-path", type=pathlib.Path, help="Path to SentencePiece model.", ) parser.add_argument( - "--use-cuda", - action="store_true", - default=False, - help="Run using CUDA.", + "--use-cuda", action="store_true", default=False, help="Run using CUDA.", ) args = parser.parse_args() run_eval(args) diff --git a/examples/asr/librispeech_conformer_rnnt/lightning.py b/examples/asr/librispeech_conformer_rnnt/lightning.py index 6f0184d5fe..8c2fbb0fc9 100644 --- a/examples/asr/librispeech_conformer_rnnt/lightning.py +++ b/examples/asr/librispeech_conformer_rnnt/lightning.py @@ -1,9 +1,5 @@ -import json import logging import math -import os -import random -from collections import namedtuple from typing import List, Tuple import sentencepiece as spm @@ -12,125 +8,17 @@ from pytorch_lightning import LightningModule, seed_everything from torchaudio.models import Hypothesis, RNNTBeamSearch from torchaudio.prototype.models import conformer_rnnt_base +from data_module import LibriSpeechDataModule +from transforms import Batch, TrainTransform, ValTransform, TestTransform logger = logging.getLogger() seed_everything(1) -Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) - - -_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max) -_gain = pow(10, 0.05 * _decibel) - -_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160) _expected_spm_vocab_size = 1023 -def _piecewise_linear_log(x): - x[x > math.e] = torch.log(x[x > math.e]) - x[x <= math.e] = x[x <= math.e] / math.e - return x - - -def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None): - batches = [] - current_batch = [] - current_token_count = 0 - for idx, target_length in idx_target_lengths: - if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit): - batches.append(current_batch) - current_batch = [idx] - current_token_count = target_length - else: - current_batch.append(idx) - current_token_count += target_length - - if current_batch: - batches.append(current_batch) - - return batches - - -def get_sample_lengths(librispeech_dataset): - fileid_to_target_length = {} - - def _target_length(fileid): - if fileid not in fileid_to_target_length: - speaker_id, chapter_id, _ = 
fileid.split("-") - - file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt - file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text) - - with open(file_text) as ft: - for line in ft: - fileid_text, transcript = line.strip().split(" ", 1) - fileid_to_target_length[fileid_text] = len(transcript) - - return fileid_to_target_length[fileid] - - return [_target_length(fileid) for fileid in librispeech_dataset._walker] - - -class CustomBucketDataset(torch.utils.data.Dataset): - def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None): - super().__init__() - - assert len(dataset) == len(lengths) - - self.dataset = dataset - - max_length = max(lengths) - min_length = min(lengths) - - assert max_token_limit >= max_length - - buckets = torch.linspace(min_length, max_length, num_buckets) - lengths = torch.tensor(lengths) - bucket_assignments = torch.bucketize(lengths, buckets) - - idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)] - if shuffle: - idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets)) - else: - idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True) - - sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2]) - self.batches = _batch_by_token_count( - [(idx, length) for idx, length, _ in sorted_idx_length_buckets], max_token_limit, sample_limit=sample_limit - ) - - def __getitem__(self, idx): - return [self.dataset[subidx] for subidx in self.batches[idx]] - - def __len__(self): - return len(self.batches) - - -class FunctionalModule(torch.nn.Module): - def __init__(self, functional): - super().__init__() - self.functional = functional - - def forward(self, input): - return self.functional(input) - - -class GlobalStatsNormalization(torch.nn.Module): - def __init__(self, global_stats_path): - super().__init__() - - with open(global_stats_path) as f: - blob = json.loads(f.read()) - - self.mean = torch.tensor(blob["mean"]) - self.invstddev = torch.tensor(blob["invstddev"]) - - def forward(self, input): - return (input - self.mean) * self.invstddev - - class WarmupLR(torch.optim.lr_scheduler._LRScheduler): r"""Learning rate scheduler that performs linear warmup and exponential annealing. 
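The hunks above leave the body of WarmupLR untouched, so the diff elides it. As a reading aid, a minimal sketch consistent with the docstring and with the call WarmupLR(self.optimizer, 40, 120, 0.96) that appears later in this file might look like the following; the class name and parameter names here are assumptions for illustration, not the recipe's actual code.

import torch


class WarmupAnnealLRSketch(torch.optim.lr_scheduler._LRScheduler):
    # Hedged sketch: linear warmup to the base learning rate, then exponential annealing.
    def __init__(self, optimizer, warmup_steps, force_anneal_step, anneal_factor, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.force_anneal_step = force_anneal_step
        self.anneal_factor = anneal_factor
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        if self._step_count < self.force_anneal_step:
            # Ramp each base learning rate linearly toward 1x over the first warmup_steps steps.
            scale = min(1.0, self._step_count / self.warmup_steps)
        else:
            # Decay multiplicatively for every step taken past force_anneal_step.
            scale = self.anneal_factor ** (self._step_count - self.force_anneal_step)
        return [scale * base_lr for base_lr in self.base_lrs]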
@@ -189,13 +77,7 @@ def post_process_hypos( class ConformerRNNTModule(LightningModule): - def __init__( - self, - *, - librispeech_path: str, - sp_model_path: str, - global_stats_path: str, - ): + def __init__(self, sp_model_path): super().__init__() self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) @@ -214,65 +96,8 @@ def __init__( self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9) self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96) - self.train_data_pipeline = torch.nn.Sequential( - FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)), - GlobalStatsNormalization(global_stats_path), - FunctionalModule(lambda x: x.transpose(1, 2)), - torchaudio.transforms.FrequencyMasking(27), - torchaudio.transforms.FrequencyMasking(27), - torchaudio.transforms.TimeMasking(100, p=0.2), - torchaudio.transforms.TimeMasking(100, p=0.2), - FunctionalModule(lambda x: x.transpose(1, 2)), - ) - self.valid_data_pipeline = torch.nn.Sequential( - FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)), - GlobalStatsNormalization(global_stats_path), - ) - - self.librispeech_path = librispeech_path - - self.train_dataset_lengths = None - self.val_dataset_lengths = None - self.automatic_optimization = False - def _extract_labels(self, samples: List): - targets = [self.sp_model.encode(sample[2].lower()) for sample in samples] - lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32) - targets = torch.nn.utils.rnn.pad_sequence( - [torch.tensor(elem) for elem in targets], - batch_first=True, - padding_value=1.0, - ).to(dtype=torch.int32) - return targets, lengths - - def _train_extract_features(self, samples: List): - mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples] - features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True) - features = self.train_data_pipeline(features) - lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32) - return features, lengths - - def _valid_extract_features(self, samples: List): - mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples] - features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True) - features = self.valid_data_pipeline(features) - lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32) - return features, lengths - - def _train_collate_fn(self, samples: List): - features, feature_lengths = self._train_extract_features(samples) - targets, target_lengths = self._extract_labels(samples) - return Batch(features, feature_lengths, targets, target_lengths) - - def _valid_collate_fn(self, samples: List): - features, feature_lengths = self._valid_extract_features(samples) - targets, target_lengths = self._extract_labels(samples) - return Batch(features, feature_lengths, targets, target_lengths) - - def _test_collate_fn(self, samples: List): - return self._valid_collate_fn(samples), samples - def _step(self, batch, _, step_type): if batch is None: return None @@ -348,55 +173,16 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): return self._step(batch, batch_idx, "test") - def train_dataloader(self): - datasets = [ - torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"), - torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"), - torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"), - ] - - 
if not self.train_dataset_lengths: - self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets] - - dataset = torch.utils.data.ConcatDataset( - [ - CustomBucketDataset(dataset, lengths, 700, 50, shuffle=False, sample_limit=2) - for dataset, lengths in zip(datasets, self.train_dataset_lengths) - ] - ) - dataloader = torch.utils.data.DataLoader( - dataset, - collate_fn=self._train_collate_fn, - num_workers=10, - batch_size=None, - shuffle=True, - ) - return dataloader - - def val_dataloader(self): - datasets = [ - torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"), - torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"), - ] - - if not self.val_dataset_lengths: - self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets] - - dataset = torch.utils.data.ConcatDataset( - [ - CustomBucketDataset(dataset, lengths, 700, 1, sample_limit=2) - for dataset, lengths in zip(datasets, self.val_dataset_lengths) - ] - ) - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=None, - collate_fn=self._valid_collate_fn, - num_workers=10, - ) - return dataloader - def test_dataloader(self): - dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean") - dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn) - return dataloader +def get_data_module(librispeech_path, global_stats_path, sp_model_path): + train_transform = TrainTransform( + global_stats_path=global_stats_path, sp_model_path=sp_model_path + ) + val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) + test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) + return LibriSpeechDataModule( + librispeech_path=librispeech_path, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + ) diff --git a/examples/asr/librispeech_conformer_rnnt/train.py b/examples/asr/librispeech_conformer_rnnt/train.py index 1c4415f3e7..0329bf8eb5 100644 --- a/examples/asr/librispeech_conformer_rnnt/train.py +++ b/examples/asr/librispeech_conformer_rnnt/train.py @@ -1,7 +1,7 @@ import pathlib from argparse import ArgumentParser -from lightning import ConformerRNNTModule +from lightning import ConformerRNNTModule, get_data_module from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor from pytorch_lightning.plugins import DDPPlugin @@ -10,20 +10,10 @@ def run_train(args): checkpoint_dir = args.exp_dir / "checkpoints" checkpoint = ModelCheckpoint( - checkpoint_dir, - monitor="Losses/val_loss", - mode="min", - save_top_k=5, - save_weights_only=False, - verbose=True, + checkpoint_dir, monitor="Losses/val_loss", mode="min", save_top_k=5, save_weights_only=False, verbose=True, ) train_checkpoint = ModelCheckpoint( - checkpoint_dir, - monitor="Losses/train_loss", - mode="min", - save_top_k=5, - save_weights_only=False, - verbose=True, + checkpoint_dir, monitor="Losses/train_loss", mode="min", save_top_k=5, save_weights_only=False, verbose=True, ) lr_monitor = LearningRateMonitor(logging_interval="step") callbacks = [ @@ -42,12 +32,9 @@ def run_train(args): reload_dataloaders_every_n_epochs=1, ) - model = ConformerRNNTModule( - librispeech_path=str(args.librispeech_path), - sp_model_path=str(args.sp_model_path), - global_stats_path=str(args.global_stats_path), - ) - trainer.fit(model) + model = ConformerRNNTModule(str(args.sp_model_path)) + 
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path)) + trainer.fit(model, data_module) def cli_main(): @@ -65,32 +52,19 @@ def cli_main(): help="Path to JSON file containing feature means and stddevs.", ) parser.add_argument( - "--librispeech-path", - type=pathlib.Path, - help="Path to LibriSpeech datasets.", + "--librispeech-path", type=pathlib.Path, help="Path to LibriSpeech datasets.", ) parser.add_argument( - "--sp-model-path", - type=pathlib.Path, - help="Path to SentencePiece model.", + "--sp-model-path", type=pathlib.Path, help="Path to SentencePiece model.", ) parser.add_argument( - "--nodes", - default=4, - type=int, - help="Number of nodes to use for training. (Default: 4)", + "--nodes", default=4, type=int, help="Number of nodes to use for training. (Default: 4)", ) parser.add_argument( - "--gpus", - default=8, - type=int, - help="Number of GPUs per node to use for training. (Default: 8)", + "--gpus", default=8, type=int, help="Number of GPUs per node to use for training. (Default: 8)", ) parser.add_argument( - "--epochs", - default=120, - type=int, - help="Number of epochs to train for. (Default: 120)", + "--epochs", default=120, type=int, help="Number of epochs to train for. (Default: 120)", ) args = parser.parse_args() diff --git a/examples/asr/librispeech_conformer_rnnt/transforms.py b/examples/asr/librispeech_conformer_rnnt/transforms.py new file mode 100644 index 0000000000..37b079745c --- /dev/null +++ b/examples/asr/librispeech_conformer_rnnt/transforms.py @@ -0,0 +1,104 @@ +import json +import math +import torch +import torchaudio +from collections import namedtuple +import sentencepiece as spm +from typing import List + + +Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) + + +_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max) +_gain = pow(10, 0.05 * _decibel) + +_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160) + + +def _piecewise_linear_log(x): + x[x > math.e] = torch.log(x[x > math.e]) + x[x <= math.e] = x[x <= math.e] / math.e + return x + + +class FunctionalModule(torch.nn.Module): + def __init__(self, functional): + super().__init__() + self.functional = functional + + def forward(self, input): + return self.functional(input) + + +class GlobalStatsNormalization(torch.nn.Module): + def __init__(self, global_stats_path): + super().__init__() + + with open(global_stats_path) as f: + blob = json.loads(f.read()) + + self.mean = torch.tensor(blob["mean"]) + self.invstddev = torch.tensor(blob["invstddev"]) + + def forward(self, input): + return (input - self.mean) * self.invstddev + + +def _extract_labels(sp_model, samples: List): + targets = [sp_model.encode(sample[2].lower()) for sample in samples] + lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32) + targets = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(elem) for elem in targets], batch_first=True, padding_value=1.0, + ).to(dtype=torch.int32) + return targets, lengths + + +def _extract_features(data_pipeline, samples: List): + mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples] + features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True) + features = data_pipeline(features) + lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32) + return features, lengths + + +class TrainTransform: + def __init__(self, 
global_stats_path: str, sp_model_path: str): + self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) + self.train_data_pipeline = torch.nn.Sequential( + FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)), + GlobalStatsNormalization(global_stats_path), + FunctionalModule(lambda x: x.transpose(1, 2)), + torchaudio.transforms.FrequencyMasking(27), + torchaudio.transforms.FrequencyMasking(27), + torchaudio.transforms.TimeMasking(100, p=0.2), + torchaudio.transforms.TimeMasking(100, p=0.2), + FunctionalModule(lambda x: x.transpose(1, 2)), + ) + + def __call__(self, samples: List): + features, feature_lengths = _extract_features(self.train_data_pipeline, samples) + targets, target_lengths = _extract_labels(self.sp_model, samples) + return Batch(features, feature_lengths, targets, target_lengths) + + +class ValTransform: + def __init__(self, global_stats_path: str, sp_model_path: str): + self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) + self.valid_data_pipeline = torch.nn.Sequential( + FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)), + GlobalStatsNormalization(global_stats_path), + ) + + def __call__(self, samples: List): + features, feature_lengths = _extract_features(self.valid_data_pipeline, samples) + targets, target_lengths = _extract_labels(self.sp_model, samples) + return Batch(features, feature_lengths, targets, target_lengths) + + +class TestTransform: + def __init__(self, global_stats_path: str, sp_model_path: str): + self.val_transforms = ValTransform(global_stats_path, sp_model_path) + + def __call__(self, sample): + return self.val_transforms([sample]), [sample] From 0078c68d379bbbfe3afcece8211c265546d45d96 Mon Sep 17 00:00:00 2001 From: hwangjeff Date: Fri, 6 May 2022 20:27:58 +0000 Subject: [PATCH 2/4] lint and partial --- .../librispeech_conformer_rnnt/data_module.py | 39 +++++++++++++++---- .../asr/librispeech_conformer_rnnt/eval.py | 17 ++++++-- .../librispeech_conformer_rnnt/lightning.py | 6 +-- .../asr/librispeech_conformer_rnnt/train.py | 37 ++++++++++++++---- .../librispeech_conformer_rnnt/transforms.py | 21 ++++++---- 5 files changed, 89 insertions(+), 31 deletions(-) diff --git a/examples/asr/librispeech_conformer_rnnt/data_module.py b/examples/asr/librispeech_conformer_rnnt/data_module.py index 291b036ab8..2d5115408d 100644 --- a/examples/asr/librispeech_conformer_rnnt/data_module.py +++ b/examples/asr/librispeech_conformer_rnnt/data_module.py @@ -1,10 +1,10 @@ +import os +import random + import torch import torchaudio from pytorch_lightning import LightningDataModule, seed_everything -import os -import random - seed_everything(1) @@ -49,7 +49,15 @@ def _target_length(fileid): class CustomBucketDataset(torch.utils.data.Dataset): - def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None): + def __init__( + self, + dataset, + lengths, + max_token_limit, + num_buckets, + shuffle=False, + sample_limit=None, + ): super().__init__() assert len(dataset) == len(lengths) @@ -73,7 +81,9 @@ def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2]) self.batches = _batch_by_token_count( - [(idx, length) for idx, length, _ in sorted_idx_length_buckets], max_token_limit, sample_limit=sample_limit + [(idx, length) for idx, length, _ in sorted_idx_length_buckets], + max_token_limit, + sample_limit=sample_limit, ) def __getitem__(self, idx): @@ -134,14 +144,21 @@ def train_dataloader(self): 
dataset = torch.utils.data.ConcatDataset( [ CustomBucketDataset( - dataset, lengths, self.max_token_limit, self.train_num_buckets, sample_limit=self.sample_limit, + dataset, + lengths, + self.max_token_limit, + self.train_num_buckets, + sample_limit=self.sample_limit, ) for dataset, lengths in zip(datasets, self.train_dataset_lengths) ] ) dataset = TransformDataset(dataset, self.train_transform) dataloader = torch.utils.data.DataLoader( - dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle + dataset, + num_workers=self.num_workers, + batch_size=None, + shuffle=self.train_shuffle, ) return dataloader @@ -156,7 +173,13 @@ def val_dataloader(self): dataset = torch.utils.data.ConcatDataset( [ - CustomBucketDataset(dataset, lengths, self.max_token_limit, 1, sample_limit=self.sample_limit) + CustomBucketDataset( + dataset, + lengths, + self.max_token_limit, + 1, + sample_limit=self.sample_limit, + ) for dataset, lengths in zip(datasets, self.val_dataset_lengths) ] ) diff --git a/examples/asr/librispeech_conformer_rnnt/eval.py b/examples/asr/librispeech_conformer_rnnt/eval.py index 3bea8ac018..bb142819f4 100644 --- a/examples/asr/librispeech_conformer_rnnt/eval.py +++ b/examples/asr/librispeech_conformer_rnnt/eval.py @@ -38,7 +38,9 @@ def run_eval(args): def cli_main(): parser = ArgumentParser() parser.add_argument( - "--checkpoint-path", type=pathlib.Path, help="Path to checkpoint to use for evaluation.", + "--checkpoint-path", + type=pathlib.Path, + help="Path to checkpoint to use for evaluation.", ) parser.add_argument( "--global-stats-path", @@ -47,13 +49,20 @@ def cli_main(): help="Path to JSON file containing feature means and stddevs.", ) parser.add_argument( - "--librispeech-path", type=pathlib.Path, help="Path to LibriSpeech datasets.", + "--librispeech-path", + type=pathlib.Path, + help="Path to LibriSpeech datasets.", ) parser.add_argument( - "--sp-model-path", type=pathlib.Path, help="Path to SentencePiece model.", + "--sp-model-path", + type=pathlib.Path, + help="Path to SentencePiece model.", ) parser.add_argument( - "--use-cuda", action="store_true", default=False, help="Run using CUDA.", + "--use-cuda", + action="store_true", + default=False, + help="Run using CUDA.", ) args = parser.parse_args() run_eval(args) diff --git a/examples/asr/librispeech_conformer_rnnt/lightning.py b/examples/asr/librispeech_conformer_rnnt/lightning.py index 8c2fbb0fc9..2abe78b9e3 100644 --- a/examples/asr/librispeech_conformer_rnnt/lightning.py +++ b/examples/asr/librispeech_conformer_rnnt/lightning.py @@ -5,10 +5,10 @@ import sentencepiece as spm import torch import torchaudio +from data_module import LibriSpeechDataModule from pytorch_lightning import LightningModule, seed_everything from torchaudio.models import Hypothesis, RNNTBeamSearch from torchaudio.prototype.models import conformer_rnnt_base -from data_module import LibriSpeechDataModule from transforms import Batch, TrainTransform, ValTransform, TestTransform logger = logging.getLogger() @@ -175,9 +175,7 @@ def test_step(self, batch, batch_idx): def get_data_module(librispeech_path, global_stats_path, sp_model_path): - train_transform = TrainTransform( - global_stats_path=global_stats_path, sp_model_path=sp_model_path - ) + train_transform = TrainTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) 
return LibriSpeechDataModule( diff --git a/examples/asr/librispeech_conformer_rnnt/train.py b/examples/asr/librispeech_conformer_rnnt/train.py index 0329bf8eb5..733a668109 100644 --- a/examples/asr/librispeech_conformer_rnnt/train.py +++ b/examples/asr/librispeech_conformer_rnnt/train.py @@ -10,10 +10,20 @@ def run_train(args): checkpoint_dir = args.exp_dir / "checkpoints" checkpoint = ModelCheckpoint( - checkpoint_dir, monitor="Losses/val_loss", mode="min", save_top_k=5, save_weights_only=False, verbose=True, + checkpoint_dir, + monitor="Losses/val_loss", + mode="min", + save_top_k=5, + save_weights_only=False, + verbose=True, ) train_checkpoint = ModelCheckpoint( - checkpoint_dir, monitor="Losses/train_loss", mode="min", save_top_k=5, save_weights_only=False, verbose=True, + checkpoint_dir, + monitor="Losses/train_loss", + mode="min", + save_top_k=5, + save_weights_only=False, + verbose=True, ) lr_monitor = LearningRateMonitor(logging_interval="step") callbacks = [ @@ -52,19 +62,32 @@ def cli_main(): help="Path to JSON file containing feature means and stddevs.", ) parser.add_argument( - "--librispeech-path", type=pathlib.Path, help="Path to LibriSpeech datasets.", + "--librispeech-path", + type=pathlib.Path, + help="Path to LibriSpeech datasets.", ) parser.add_argument( - "--sp-model-path", type=pathlib.Path, help="Path to SentencePiece model.", + "--sp-model-path", + type=pathlib.Path, + help="Path to SentencePiece model.", ) parser.add_argument( - "--nodes", default=4, type=int, help="Number of nodes to use for training. (Default: 4)", + "--nodes", + default=4, + type=int, + help="Number of nodes to use for training. (Default: 4)", ) parser.add_argument( - "--gpus", default=8, type=int, help="Number of GPUs per node to use for training. (Default: 8)", + "--gpus", + default=8, + type=int, + help="Number of GPUs per node to use for training. (Default: 8)", ) parser.add_argument( - "--epochs", default=120, type=int, help="Number of epochs to train for. (Default: 120)", + "--epochs", + default=120, + type=int, + help="Number of epochs to train for. 
(Default: 120)", ) args = parser.parse_args() diff --git a/examples/asr/librispeech_conformer_rnnt/transforms.py b/examples/asr/librispeech_conformer_rnnt/transforms.py index 37b079745c..80531ea373 100644 --- a/examples/asr/librispeech_conformer_rnnt/transforms.py +++ b/examples/asr/librispeech_conformer_rnnt/transforms.py @@ -1,11 +1,13 @@ import json import math -import torch -import torchaudio from collections import namedtuple -import sentencepiece as spm +from functools import partial from typing import List +import sentencepiece as spm +import torch +import torchaudio + Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) @@ -17,6 +19,7 @@ def _piecewise_linear_log(x): + x = x * _gain x[x > math.e] = torch.log(x[x > math.e]) x[x <= math.e] = x[x <= math.e] / math.e return x @@ -49,7 +52,9 @@ def _extract_labels(sp_model, samples: List): targets = [sp_model.encode(sample[2].lower()) for sample in samples] lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32) targets = torch.nn.utils.rnn.pad_sequence( - [torch.tensor(elem) for elem in targets], batch_first=True, padding_value=1.0, + [torch.tensor(elem) for elem in targets], + batch_first=True, + padding_value=1.0, ).to(dtype=torch.int32) return targets, lengths @@ -66,14 +71,14 @@ class TrainTransform: def __init__(self, global_stats_path: str, sp_model_path: str): self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) self.train_data_pipeline = torch.nn.Sequential( - FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)), + FunctionalModule(_piecewise_linear_log), GlobalStatsNormalization(global_stats_path), - FunctionalModule(lambda x: x.transpose(1, 2)), + FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)), torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.TimeMasking(100, p=0.2), torchaudio.transforms.TimeMasking(100, p=0.2), - FunctionalModule(lambda x: x.transpose(1, 2)), + FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)), ) def __call__(self, samples: List): @@ -86,7 +91,7 @@ class ValTransform: def __init__(self, global_stats_path: str, sp_model_path: str): self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) self.valid_data_pipeline = torch.nn.Sequential( - FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)), + FunctionalModule(_piecewise_linear_log), GlobalStatsNormalization(global_stats_path), ) From 3196a2ec8e654a24f28ee3aca8e34a7ad5602b32 Mon Sep 17 00:00:00 2001 From: hwangjeff Date: Sat, 7 May 2022 03:20:56 +0000 Subject: [PATCH 3/4] allow for restarting training from checkpoint; annotate required args --- examples/asr/librispeech_conformer_rnnt/eval.py | 3 +++ examples/asr/librispeech_conformer_rnnt/train.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/asr/librispeech_conformer_rnnt/eval.py b/examples/asr/librispeech_conformer_rnnt/eval.py index bb142819f4..15c4cb6646 100644 --- a/examples/asr/librispeech_conformer_rnnt/eval.py +++ b/examples/asr/librispeech_conformer_rnnt/eval.py @@ -41,6 +41,7 @@ def cli_main(): "--checkpoint-path", type=pathlib.Path, help="Path to checkpoint to use for evaluation.", + required=True, ) parser.add_argument( "--global-stats-path", @@ -52,11 +53,13 @@ def cli_main(): "--librispeech-path", type=pathlib.Path, help="Path to LibriSpeech datasets.", + required=True, ) parser.add_argument( "--sp-model-path", type=pathlib.Path, help="Path to SentencePiece model.", + 
required=True,
     )
     parser.add_argument(
         "--use-cuda",
diff --git a/examples/asr/librispeech_conformer_rnnt/train.py b/examples/asr/librispeech_conformer_rnnt/train.py
index 733a668109..cf3324359f 100644
--- a/examples/asr/librispeech_conformer_rnnt/train.py
+++ b/examples/asr/librispeech_conformer_rnnt/train.py
@@ -44,11 +44,17 @@ def run_train(args):
 
     model = ConformerRNNTModule(str(args.sp_model_path))
     data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
-    trainer.fit(model, data_module)
+    trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
 
 
 def cli_main():
     parser = ArgumentParser()
+    parser.add_argument(
+        "--checkpoint-path",
+        default=None,
+        type=pathlib.Path,
+        help="Path to checkpoint to resume training from.",
+    )
     parser.add_argument(
         "--exp-dir",
         default=pathlib.Path("./exp"),
@@ -65,11 +71,13 @@ def cli_main():
         "--librispeech-path",
         type=pathlib.Path,
         help="Path to LibriSpeech datasets.",
+        required=True,
     )
     parser.add_argument(
         "--sp-model-path",
         type=pathlib.Path,
         help="Path to SentencePiece model.",
+        required=True,
     )
     parser.add_argument(
         "--nodes",

From 8f08ebe87005c39d9b99d344fc78c0ab8592f26f Mon Sep 17 00:00:00 2001
From: hwangjeff
Date: Wed, 11 May 2022 03:46:08 +0000
Subject: [PATCH 4/4] address feedback

---
 .../librispeech_conformer_rnnt/data_module.py | 35 +++++++++----------
 .../librispeech_conformer_rnnt/lightning.py   |  5 +--
 .../asr/librispeech_conformer_rnnt/train.py   |  4 +--
 3 files changed, 19 insertions(+), 25 deletions(-)

diff --git a/examples/asr/librispeech_conformer_rnnt/data_module.py b/examples/asr/librispeech_conformer_rnnt/data_module.py
index 2d5115408d..e20b715948 100644
--- a/examples/asr/librispeech_conformer_rnnt/data_module.py
+++ b/examples/asr/librispeech_conformer_rnnt/data_module.py
@@ -3,18 +3,15 @@
 
 import torch
 import torchaudio
-from pytorch_lightning import LightningDataModule, seed_everything
+from pytorch_lightning import LightningDataModule
 
 
-seed_everything(1)
-
-
-def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None):
+def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None):
     batches = []
     current_batch = []
     current_token_count = 0
     for idx, target_length in idx_target_lengths:
-        if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit):
+        if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size):
             batches.append(current_batch)
             current_batch = [idx]
             current_token_count = target_length
@@ -53,10 +50,10 @@ def __init__(
         self,
         dataset,
         lengths,
-        max_token_limit,
+        max_tokens,
         num_buckets,
         shuffle=False,
-        sample_limit=None,
+        batch_size=None,
     ):
         super().__init__()
 
@@ -67,7 +64,7 @@ def __init__(
         max_length = max(lengths)
         min_length = min(lengths)
 
-        assert max_token_limit >= max_length
+        assert max_tokens >= max_length
 
         buckets = torch.linspace(min_length, max_length, num_buckets)
         lengths = torch.tensor(lengths)
@@ -82,8 +79,8 @@ def __init__(
         sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
         self.batches = _batch_by_token_count(
             [(idx, length) for idx, length, _ in sorted_idx_length_buckets],
-            max_token_limit,
-            sample_limit=sample_limit,
+            max_tokens,
+            batch_size=batch_size,
         )
 
@@ -113,8 +110,8 @@ def __init__(
         train_transform,
         val_transform,
         test_transform,
-        max_token_limit=700,
-        sample_limit=2,
+        max_tokens=700,
+        batch_size=2,
         train_num_buckets=50,
         train_shuffle=True,
         num_workers=10,
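To make the renames in the hunk above concrete: _batch_by_token_count packs sample indices greedily, closing a batch as soon as the next sample would push the running target-length total over max_tokens, or as soon as the batch already holds batch_size samples. A small worked example with made-up lengths, assuming the function has been imported from data_module:

pairs = [(0, 300), (1, 250), (2, 200), (3, 90), (4, 60), (5, 55)]  # (index, target_length)
print(_batch_by_token_count(pairs, max_tokens=700, batch_size=2))
# [[0, 1], [2, 3], [4, 5]]
# [0, 1] closes because adding index 2 would push 550 + 200 past max_tokens=700;
# [2, 3] closes because it already holds batch_size=2 samples when index 4 arrives.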
@@ -125,8 +122,8 @@ def __init__( self.train_transform = train_transform self.val_transform = val_transform self.test_transform = test_transform - self.max_token_limit = max_token_limit - self.sample_limit = sample_limit + self.max_tokens = max_tokens + self.batch_size = batch_size self.train_num_buckets = train_num_buckets self.train_shuffle = train_shuffle self.num_workers = num_workers @@ -146,9 +143,9 @@ def train_dataloader(self): CustomBucketDataset( dataset, lengths, - self.max_token_limit, + self.max_tokens, self.train_num_buckets, - sample_limit=self.sample_limit, + batch_size=self.batch_size, ) for dataset, lengths in zip(datasets, self.train_dataset_lengths) ] @@ -176,9 +173,9 @@ def val_dataloader(self): CustomBucketDataset( dataset, lengths, - self.max_token_limit, + self.max_tokens, 1, - sample_limit=self.sample_limit, + batch_size=self.batch_size, ) for dataset, lengths in zip(datasets, self.val_dataset_lengths) ] diff --git a/examples/asr/librispeech_conformer_rnnt/lightning.py b/examples/asr/librispeech_conformer_rnnt/lightning.py index 2abe78b9e3..ea62aec41f 100644 --- a/examples/asr/librispeech_conformer_rnnt/lightning.py +++ b/examples/asr/librispeech_conformer_rnnt/lightning.py @@ -6,16 +6,13 @@ import torch import torchaudio from data_module import LibriSpeechDataModule -from pytorch_lightning import LightningModule, seed_everything +from pytorch_lightning import LightningModule from torchaudio.models import Hypothesis, RNNTBeamSearch from torchaudio.prototype.models import conformer_rnnt_base from transforms import Batch, TrainTransform, ValTransform, TestTransform logger = logging.getLogger() -seed_everything(1) - - _expected_spm_vocab_size = 1023 diff --git a/examples/asr/librispeech_conformer_rnnt/train.py b/examples/asr/librispeech_conformer_rnnt/train.py index cf3324359f..9895460396 100644 --- a/examples/asr/librispeech_conformer_rnnt/train.py +++ b/examples/asr/librispeech_conformer_rnnt/train.py @@ -2,12 +2,13 @@ from argparse import ArgumentParser from lightning import ConformerRNNTModule, get_data_module -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor from pytorch_lightning.plugins import DDPPlugin def run_train(args): + seed_everything(1) checkpoint_dir = args.exp_dir / "checkpoints" checkpoint = ModelCheckpoint( checkpoint_dir, @@ -98,7 +99,6 @@ def cli_main(): help="Number of epochs to train for. (Default: 120)", ) args = parser.parse_args() - run_train(args)
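With the full stack of patches applied, the entry points stay small. The following condensed sketch, mirroring train.py, shows how the refactored pieces compose; the paths and file names are placeholders rather than values taken from this patch set.

from pytorch_lightning import Trainer, seed_everything

from lightning import ConformerRNNTModule, get_data_module

seed_everything(1)  # patch 4 moves seeding into the entry point instead of module import time

# get_data_module wires TrainTransform, ValTransform, and TestTransform into the
# LightningDataModule; the LightningModule itself only needs the SentencePiece model.
model = ConformerRNNTModule("/path/to/spm.model")
data_module = get_data_module("/path/to/librispeech", "/path/to/global_stats.json", "/path/to/spm.model")

trainer = Trainer(max_epochs=120)
trainer.fit(model, data_module, ckpt_path=None)  # pass a checkpoint path here to resume training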