From 9effd276510d8a6530a5def9a5f1d9590f2b7144 Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Thu, 3 Feb 2022 14:01:04 +0000 Subject: [PATCH 1/7] add recipe for HuBERT model pre-training --- examples/hubert/dataset/hubert_dataset.py | 10 +- examples/hubert/lightning.py | 133 +++++++++++++++ examples/hubert/loss/__init__.py | 5 + examples/hubert/loss/hubert_loss.py | 36 ++++ examples/hubert/train.py | 196 ++++++++++++++++++++++ 5 files changed, 375 insertions(+), 5 deletions(-) create mode 100644 examples/hubert/lightning.py create mode 100644 examples/hubert/loss/__init__.py create mode 100644 examples/hubert/loss/hubert_loss.py diff --git a/examples/hubert/dataset/hubert_dataset.py b/examples/hubert/dataset/hubert_dataset.py index 5dc37c8ae5..97c559a821 100644 --- a/examples/hubert/dataset/hubert_dataset.py +++ b/examples/hubert/dataset/hubert_dataset.py @@ -215,20 +215,20 @@ class HuBERTDataSet(Dataset): """Create a Dataset for HuBERT model training and fine-tuning. Args: - exp_dir (str or Path): The root directory of the ``.tsv`` file list. + root_dir (str or Path): The root directory that contains ``tsv`` and ``label`` directories. dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``]. subset (str): The subset of the dataset. Options: [``train``, ``valid``]. """ def __init__( self, - exp_dir: Union[str, Path], + root_dir: Union[str, Path], dataset: str, subset: str, ) -> None: - self.exp_dir = Path(exp_dir) - tsv_dir = self.exp_dir / "tsv" - label_dir = self.exp_dir / "label" + self.root_dir = Path(root_dir) + tsv_dir = self.root_dir / "tsv" + label_dir = self.root_dir / "label" f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset) self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list self.labels = self._load_labels(label_dir, dataset, subset) diff --git a/examples/hubert/lightning.py b/examples/hubert/lightning.py new file mode 100644 index 0000000000..2b6eaaeedd --- /dev/null +++ b/examples/hubert/lightning.py @@ -0,0 +1,133 @@ +from typing import Tuple + +import torch +import torchaudio +from dataset import BucketizeSampler, DistributedBatchSampler, HuBERTDataSet, CollateFnHubert +from loss import hubert_loss +from pytorch_lightning import LightningModule +from torch import Tensor +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + + +Batch = Tuple[Tensor, Tensor, Tensor] + + +class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler): + """Linear learning rate scheduler with warm up.""" + + def __init__( + self, + optimizer: Optimizer, + warmup_updates: int, + max_updates: int, + last_epoch: int = -1, + verbose: bool = False, + ): + self.warmup_updates = warmup_updates + self.max_updates = max_updates + super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) + + def get_lr(self): + if self._step_count <= self.warmup_updates: + return [self._step_count / self.warmup_updates * base_lr for base_lr in self.base_lrs] + elif self._step_count >= self.max_updates: + return [0.0 for _ in self.base_lrs] + else: + pct_remaining = (self.max_updates - self._step_count) / (self.max_updates - self.warmup_updates) + return [base_lr * pct_remaining for base_lr in self.base_lrs] + + +class HuBERTPreTrainModule(LightningModule): + def __init__( + self, + *, + model_name: str, + num_classes: int, + dataset: str, + root_path: str, + feature_type: str, + seconds_per_batch: float, + learning_rate: float, + betas: Tuple[float, float], + eps: float, + weight_decay: float, + 
warmup_updates: int, + max_updates: int, + ): + super().__init__() + + if model_name == "hubert_pretrain_base": + self.model = torchaudio.models.hubert_pretrain_base(num_classes=num_classes) + elif model_name == "hubert_pretrain_large": + self.model = torchaudio.models.hubert_pretrain_large() + elif model_name == "hubert_pretrain_xlarge": + self.model = torchaudio.models.hubert_pretrain_xlarge() + else: + raise ValueError(f"Unsupported model name: {model_name}") + + self.loss = hubert_loss + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay + ) + self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates) + self.dataset = dataset + self.root_path = root_path + self.feature_type = feature_type + self.seconds_per_batch = seconds_per_batch + + def _step(self, batch, batch_idx, step_type): + if batch is None: + return None + waveforms, labels, audio_lengths = batch + logit_m, logit_u, feature_pen = self.model( + waveforms, + labels, + audio_lengths, + ) + loss = self.loss(logit_m, logit_u, feature_pen) + self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True) + return loss + + def configure_optimizers(self): + return ( + [self.optimizer], + [ + { + "scheduler": self.lr_scheduler, + "interval": "step", + }, + ], + ) + + def training_step(self, batch: Batch, batch_idx): + return self._step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "val") + + def train_dataloader(self): + dataset = HuBERTDataSet(self.root_path, self.dataset, "train") + sampler = BucketizeSampler(dataset, num_buckets=1000, max_token_count=self.seconds_per_batch * 16000) + sampler = DistributedBatchSampler(sampler) + dataloader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True), + num_workers=10, + pin_memory=True, + ) + return dataloader + + def val_dataloader(self): + dataset = HuBERTDataSet(self.root_path, self.dataset, "valid") + sampler = BucketizeSampler(dataset, num_buckets=1000, max_token_count=self.seconds_per_batch * 16000) + sampler = DistributedBatchSampler(sampler) + dataloader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True), + num_workers=10, + pin_memory=True, + ) + return dataloader diff --git a/examples/hubert/loss/__init__.py b/examples/hubert/loss/__init__.py new file mode 100644 index 0000000000..ad252d8158 --- /dev/null +++ b/examples/hubert/loss/__init__.py @@ -0,0 +1,5 @@ +from .hubert_loss import hubert_loss + +__all__ = [ + "hubert_loss", +] diff --git a/examples/hubert/loss/hubert_loss.py b/examples/hubert/loss/hubert_loss.py new file mode 100644 index 0000000000..a3d5b93e18 --- /dev/null +++ b/examples/hubert/loss/hubert_loss.py @@ -0,0 +1,36 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def hubert_loss( + logit_m: Optional[Tensor], + logit_u: Optional[Tensor], + feature_pen: Tensor, + masked_weight: float = 1.0, + nomask_weight: float = 0.0, + feature_weight: float = 10.0, + reduction: str = "sum", +) -> Tensor: + """Compute the cross-entropy loss on HuBERT masked and non-masked logits. + Args: + logit_m (Tensor or None): The masked logit Tensor of dimension `[masked_frames, final_dim]`. 
+ logit_u (Tensor or None): The non-masked logit Tensor of dimension `[nonmasked_frames, final_dim]`. + feature_pen (Tensor): The feature mean value for additional penalty loss. + masked_weight (float, optional): The weight for masked cross-entropy loss (Default: ``1.0``). + nomask_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``). + feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``). + reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``). + """ + loss = feature_pen * feature_weight + if logit_m is not None: + target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device) + loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction) + loss += loss_m * masked_weight + if logit_u is not None: + target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_m.device) + loss_u = F.cross_entropy(logit_u, target_u, reduction=reduction) + loss += loss_u * nomask_weight + return loss diff --git a/examples/hubert/train.py b/examples/hubert/train.py index e69de29bb2..66a2e1096d 100644 --- a/examples/hubert/train.py +++ b/examples/hubert/train.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +"""Train the HuBERTPretrain model by using 960 hours of LibriSpeech training sets. +Example: +python train.py --root-path ./exp/data/mfcc/ --feature-type mfcc --num-classes 100 +""" + +import logging +import pathlib +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter +from typing import Tuple + +from lightning import HuBERTPreTrainModule +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint + + +logger = logging.getLogger(__name__) + + +class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter): + # https://stackoverflow.com/a/18462760 + pass + + +def run_train(args): + checkpoint_dir = args.exp_dir / f"checkpoints_{args.dataset}_{args.model_name}" + checkpoint = ModelCheckpoint( + checkpoint_dir, + monitor="Losses/val_loss", + mode="min", + save_top_k=5, + save_weights_only=True, + verbose=True, + ) + train_checkpoint = ModelCheckpoint( + checkpoint_dir, + monitor="Losses/train_loss", + mode="min", + save_top_k=5, + save_weights_only=True, + verbose=True, + ) + callbacks = [ + checkpoint, + train_checkpoint, + ] + trainer = Trainer( + default_root_dir=args.exp_dir, + max_steps=args.max_updates, + num_nodes=args.num_nodes, + gpus=args.gpus, + accelerator="gpu", + strategy="ddp", + replace_sampler_ddp=False, + gradient_clip_val=args.clip_norm, + callbacks=callbacks, + ) + + model = HuBERTPreTrainModule( + model_name=args.model_name, + num_classes=args.num_classes, + dataset=args.dataset, + root_path=args.root_path, + feature_type=args.feature_type, + seconds_per_batch=args.seconds_per_batch, + learning_rate=args.learning_rate, + betas=args.betas, + eps=args.eps, + weight_decay=args.weight_decay, + warmup_updates=args.warmup_updates, + max_updates=args.max_updates, + ) + trainer.fit(model) + + +def _parse_args(): + parser = ArgumentParser( + description=__doc__, + formatter_class=_Formatter, + ) + parser.add_argument( + "--root-path", + type=pathlib.Path, + required=True, + help="Path to the feature and label directories.", + ) + parser.add_argument( + "--feature-type", + choices=["mfcc", "hubert"], + type=str, + required=True, + ) + parser.add_argument( + "--num-classes", + choices=[100, 500], + type=int, + required=True, + help="The ``num_class`` when building the 
hubert_pretrain_base model.", + ) + parser.add_argument( + "--model-name", + default="hubert_pretrain_base", + choices=["hubert_pretrain_base", "hubert_pretrain_large", "hubert_pretrain_xlarge"], + type=str, + help="The HuBERT model to train.", + ) + parser.add_argument( + "--exp-dir", + default=pathlib.Path("./exp"), + type=pathlib.Path, + help="Directory to save checkpoints and logs to. (Default: './exp')", + ) + parser.add_argument( + "--dataset", + default="librispeech", + choices=["librispeech", "librilight"], + type=str, + help="The dataset for training. (Default: 'librispeech')", + ) + parser.add_argument( + "--learning-rate", + default=0.003, + type=float, + ) + parser.add_argument( + "--betas", + default=(0.9, 0.98), + type=Tuple, + help=" coefficients for computing running averages of gradient and its square (default: (0.9, 0.98))", + ) + parser.add_argument( + "--eps", + default=1e-6, + type=float, + help="Epsilon value in Adam optimizer. (Default: 1e-6)", + ) + parser.add_argument( + "--weight-decay", + default=0.01, + type=float, + help="Weight decay (L2 penalty) (default: 0.01)", + ) + parser.add_argument( + "--clip-norm", + default=1.0, + type=float, + help="The gradient norm value to clip. (Default: 1.0)", + ) + parser.add_argument( + "--num_nodes", + default=1, + type=int, + help="Number of nodes to use for training. (Default: 1)", + ) + parser.add_argument( + "--gpus", + default=8, + type=int, + help="Number of GPUs per node to use for training. (Default: 8)", + ) + parser.add_argument( + "--warmup-updates", + default=32000, + type=int, + help="Number of steps for warm up the learning rate. (Default: 32000)", + ) + parser.add_argument( + "--max-updates", + default=250000, + type=int, + help="Total number of training steps. (Default: 250000)", + ) + parser.add_argument( + "--seconds-per-batch", + default=87.5, + type=float, + help="Number of seconds of audio in a mini-batch. (Default: 87.5)", + ) + parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") + return parser.parse_args() + + +def _init_logger(debug): + fmt = "%(asctime)s %(message)s" if debug else "%(message)s" + level = logging.DEBUG if debug else logging.INFO + logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") + + +def cli_main(): + args = _parse_args() + _init_logger(args.debug) + run_train(args) + + +if __name__ == "__main__": + cli_main() From 3e40b5bf72084b9e34e5c0f4fe32a0f2f2a711b3 Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Wed, 30 Mar 2022 17:59:23 +0100 Subject: [PATCH 2/7] fix --- examples/hubert/dataset/hubert_dataset.py | 10 ++++---- examples/hubert/lightning.py | 30 ++++++++++++++++------- examples/hubert/loss/hubert_loss.py | 2 +- examples/hubert/train.py | 13 +++------- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/examples/hubert/dataset/hubert_dataset.py b/examples/hubert/dataset/hubert_dataset.py index 97c559a821..5dc37c8ae5 100644 --- a/examples/hubert/dataset/hubert_dataset.py +++ b/examples/hubert/dataset/hubert_dataset.py @@ -215,20 +215,20 @@ class HuBERTDataSet(Dataset): """Create a Dataset for HuBERT model training and fine-tuning. Args: - root_dir (str or Path): The root directory that contains ``tsv`` and ``label`` directories. + exp_dir (str or Path): The root directory of the ``.tsv`` file list. dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``]. subset (str): The subset of the dataset. Options: [``train``, ``valid``]. 
""" def __init__( self, - root_dir: Union[str, Path], + exp_dir: Union[str, Path], dataset: str, subset: str, ) -> None: - self.root_dir = Path(root_dir) - tsv_dir = self.root_dir / "tsv" - label_dir = self.root_dir / "label" + self.exp_dir = Path(exp_dir) + tsv_dir = self.exp_dir / "tsv" + label_dir = self.exp_dir / "label" f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset) self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list self.labels = self._load_labels(label_dir, dataset, subset) diff --git a/examples/hubert/lightning.py b/examples/hubert/lightning.py index 2b6eaaeedd..0b2c30751b 100644 --- a/examples/hubert/lightning.py +++ b/examples/hubert/lightning.py @@ -2,7 +2,7 @@ import torch import torchaudio -from dataset import BucketizeSampler, DistributedBatchSampler, HuBERTDataSet, CollateFnHubert +from dataset import BucketizeBatchSampler, DistributedBatchSampler, HuBERTDataSet, CollateFnHubert from loss import hubert_loss from pytorch_lightning import LightningModule from torch import Tensor @@ -76,7 +76,7 @@ def __init__( self.feature_type = feature_type self.seconds_per_batch = seconds_per_batch - def _step(self, batch, batch_idx, step_type): + def _step(self, batch: Batch, batch_idx, step_type): if batch is None: return None waveforms, labels, audio_lengths = batch @@ -103,31 +103,43 @@ def configure_optimizers(self): def training_step(self, batch: Batch, batch_idx): return self._step(batch, batch_idx, "train") - def validation_step(self, batch, batch_idx): + def validation_step(self, batch: Batch, batch_idx): return self._step(batch, batch_idx, "val") def train_dataloader(self): dataset = HuBERTDataSet(self.root_path, self.dataset, "train") - sampler = BucketizeSampler(dataset, num_buckets=1000, max_token_count=self.seconds_per_batch * 16000) - sampler = DistributedBatchSampler(sampler) + sampler = BucketizeBatchSampler( + dataset.len_list, + num_buckets=10000, + max_token_count=self.seconds_per_batch * 16000, + min_len=32000, + max_len=250000, + shuffle=True, + ) + sampler = DistributedBatchSampler(sampler, shuffle=True) + sampler.set_epoch(self.current_epoch) dataloader = DataLoader( dataset, batch_sampler=sampler, collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True), num_workers=10, - pin_memory=True, ) return dataloader def val_dataloader(self): dataset = HuBERTDataSet(self.root_path, self.dataset, "valid") - sampler = BucketizeSampler(dataset, num_buckets=1000, max_token_count=self.seconds_per_batch * 16000) - sampler = DistributedBatchSampler(sampler) + sampler = BucketizeBatchSampler( + dataset.len_list, + num_buckets=1000, + max_token_count=self.seconds_per_batch * 16000, + min_len=32000, + max_len=250000, + shuffle=False, + ) dataloader = DataLoader( dataset, batch_sampler=sampler, collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True), num_workers=10, - pin_memory=True, ) return dataloader diff --git a/examples/hubert/loss/hubert_loss.py b/examples/hubert/loss/hubert_loss.py index a3d5b93e18..b8cc936674 100644 --- a/examples/hubert/loss/hubert_loss.py +++ b/examples/hubert/loss/hubert_loss.py @@ -24,7 +24,7 @@ def hubert_loss( feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``). reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``). 
""" - loss = feature_pen * feature_weight + loss = feature_pen * feature_weight * logit_m.shape[0] if logit_m is not None: target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device) loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction) diff --git a/examples/hubert/train.py b/examples/hubert/train.py index 66a2e1096d..8f6ea066f0 100644 --- a/examples/hubert/train.py +++ b/examples/hubert/train.py @@ -32,17 +32,9 @@ def run_train(args): save_weights_only=True, verbose=True, ) - train_checkpoint = ModelCheckpoint( - checkpoint_dir, - monitor="Losses/train_loss", - mode="min", - save_top_k=5, - save_weights_only=True, - verbose=True, - ) + callbacks = [ checkpoint, - train_checkpoint, ] trainer = Trainer( default_root_dir=args.exp_dir, @@ -54,6 +46,7 @@ def run_train(args): replace_sampler_ddp=False, gradient_clip_val=args.clip_norm, callbacks=callbacks, + reload_dataloaders_every_n_epochs=1, ) model = HuBERTPreTrainModule( @@ -142,7 +135,7 @@ def _parse_args(): ) parser.add_argument( "--clip-norm", - default=1.0, + default=None, type=float, help="The gradient norm value to clip. (Default: 1.0)", ) From acd8a2058a024aee43479ce705eb534083daeb19 Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Wed, 30 Mar 2022 18:05:33 +0100 Subject: [PATCH 3/7] fix arg strings --- examples/hubert/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/hubert/train.py b/examples/hubert/train.py index 8f6ea066f0..0173bc0001 100644 --- a/examples/hubert/train.py +++ b/examples/hubert/train.py @@ -95,7 +95,7 @@ def _parse_args(): default="hubert_pretrain_base", choices=["hubert_pretrain_base", "hubert_pretrain_large", "hubert_pretrain_xlarge"], type=str, - help="The HuBERT model to train.", + help="The HuBERTPretrainModel for pre-training.", ) parser.add_argument( "--exp-dir", @@ -137,7 +137,7 @@ def _parse_args(): "--clip-norm", default=None, type=float, - help="The gradient norm value to clip. (Default: 1.0)", + help="The gradient norm value to clip. 
(Default: ``None``)", ) parser.add_argument( "--num_nodes", From db0d084a4420089d222ac938ff888eecc5886d1c Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Tue, 3 May 2022 19:02:17 +0100 Subject: [PATCH 4/7] add resume_checkpoint option, add feature_grad_mult for Base model --- examples/hubert/lightning.py | 5 ++++- examples/hubert/train.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/examples/hubert/lightning.py b/examples/hubert/lightning.py index 0b2c30751b..b4aede20cf 100644 --- a/examples/hubert/lightning.py +++ b/examples/hubert/lightning.py @@ -43,6 +43,7 @@ def __init__( self, *, model_name: str, + feature_grad_mult: float, num_classes: int, dataset: str, root_path: str, @@ -58,7 +59,9 @@ def __init__( super().__init__() if model_name == "hubert_pretrain_base": - self.model = torchaudio.models.hubert_pretrain_base(num_classes=num_classes) + self.model = torchaudio.models.hubert_pretrain_base( + feature_grad_mult=feature_grad_mult, num_classes=num_classes + ) elif model_name == "hubert_pretrain_large": self.model = torchaudio.models.hubert_pretrain_large() elif model_name == "hubert_pretrain_xlarge": diff --git a/examples/hubert/train.py b/examples/hubert/train.py index 0173bc0001..0e0eeaff04 100644 --- a/examples/hubert/train.py +++ b/examples/hubert/train.py @@ -7,7 +7,7 @@ import logging import pathlib from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter -from typing import Tuple +from typing import Optional, Tuple from lightning import HuBERTPreTrainModule from pytorch_lightning import Trainer @@ -29,12 +29,20 @@ def run_train(args): monitor="Losses/val_loss", mode="min", save_top_k=5, - save_weights_only=True, + save_weights_only=False, + verbose=True, + ) + train_checkpoint = ModelCheckpoint( + checkpoint_dir, + monitor="Losses/train_loss", + mode="min", + save_top_k=5, + save_weights_only=False, verbose=True, ) - callbacks = [ checkpoint, + train_checkpoint, ] trainer = Trainer( default_root_dir=args.exp_dir, @@ -51,6 +59,7 @@ def run_train(args): model = HuBERTPreTrainModule( model_name=args.model_name, + feature_grad_mult=args.feature_grad_mult, num_classes=args.num_classes, dataset=args.dataset, root_path=args.root_path, @@ -63,7 +72,7 @@ def run_train(args): warmup_updates=args.warmup_updates, max_updates=args.max_updates, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=args.resume_checkpoint) def _parse_args(): @@ -77,12 +86,24 @@ def _parse_args(): required=True, help="Path to the feature and label directories.", ) + parser.add_argument( + "--resume-checkpoint", + type=Optional[pathlib.Path], + default=None, + help="Path to the feature and label directories.", + ) parser.add_argument( "--feature-type", choices=["mfcc", "hubert"], type=str, required=True, ) + parser.add_argument( + "--feature-grad-mult", + default=0.1, + type=float, + help="The factor to multiply the feature extractor gradient. (Default: 0.1)", + ) parser.add_argument( "--num-classes", choices=[100, 500], @@ -140,7 +161,7 @@ def _parse_args(): help="The gradient norm value to clip. (Default: ``None``)", ) parser.add_argument( - "--num_nodes", + "--num-nodes", default=1, type=int, help="Number of nodes to use for training. 
(Default: 1)",
     )
     parser.add_argument(
         "--gpus",

From ee8872ae752fd579d2d3e143b7fc42499bbe8f13 Mon Sep 17 00:00:00 2001
From: Zhaoheng Ni
Date: Thu, 5 May 2022 21:31:51 +0100
Subject: [PATCH 5/7] refactor docstring, add README

---
 examples/hubert/README.md    | 46 +++++++++++++++++++++++++++++++++++++
 examples/hubert/lightning.py |  2 +-
 examples/hubert/train.py     | 23 ++++++++++----------
 3 files changed, 59 insertions(+), 12 deletions(-)
 create mode 100644 examples/hubert/README.md

diff --git a/examples/hubert/README.md b/examples/hubert/README.md
new file mode 100644
index 0000000000..633b89f20f
--- /dev/null
+++ b/examples/hubert/README.md
@@ -0,0 +1,46 @@
+# HuBERT Pre-training Example
+
+This directory contains a sample implementation of the pre-training pipeline for [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447).
+
+## Usage
+
+The Base architecture of the HuBERT model requires two iterations of pre-training.
+### Pre-processing (1st iteration)
+[`preprocess.py`](./preprocess.py) generates the file lists of training and validation data, trains a KMeans clustering model on either MFCC features or a transformer layer's output of a pre-trained HuBERT model, then predicts the cluster IDs for each utterance as the labels for masked prediction training.
+
+Sample SLURM command for the first iteration of pre-processing, which uses MFCC features to train the KMeans model:
+```
+srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type mfcc --exp-dir ./exp --num-cluster 100
+```
+
+### Pre-training (1st iteration)
+
+[`train.py`](./train.py) trains a HuBERTPretrainModel using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training.
+
+The first iteration is trained for 250k steps on 32 GPUs, with at most 87.5 seconds of audio per GPU in each mini-batch.
+
+Sample SLURM command for the first iteration of pre-training:
+```
+srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/mfcc/ --exp-dir ./exp_iter1 --feature-type mfcc --num-classes 100 --max-updates 250000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
+```
+
+### Pre-processing (2nd iteration)
+After the first iteration of pre-training, an intermediate transformer layer's output of the pre-trained HuBERTPretrainModel can be used to train a new KMeans clustering model, which then generates new cluster labels for the second iteration of masked prediction training.
+
+Sample SLURM command for the second iteration of pre-processing. The 6th transformer layer's output is used as the input feature for training the KMeans model. Note that the number of clusters is increased to 500 to improve performance.
+```
+srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type hubert --exp-dir ./exp --layer-index 6 --checkpoint-path ./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt --num-cluster 500
+```
+
+### Pre-training (2nd iteration)
+The second iteration is trained for 400k steps.
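+
+If a run is interrupted, it can be resumed from a saved Lightning checkpoint by appending the `--resume-checkpoint` option to the training command below, e.g. (the checkpoint file name is a placeholder for one written under `--exp-dir`):
+```
+--resume-checkpoint ./exp_iter2/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt
+```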
+
+Sample SLURM command for the second iteration of pre-training:
+```
+srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/hubert_6/ --exp-dir ./exp_iter2 --feature-type hubert --num-classes 500 --max-updates 400000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
+```
diff --git a/examples/hubert/lightning.py b/examples/hubert/lightning.py
index b4aede20cf..ba01e9b32e 100644
--- a/examples/hubert/lightning.py
+++ b/examples/hubert/lightning.py
@@ -2,7 +2,7 @@
 
 import torch
 import torchaudio
-from dataset import BucketizeBatchSampler, DistributedBatchSampler, HuBERTDataSet, CollateFnHubert
+from dataset import BucketizeBatchSampler, CollateFnHubert, DistributedBatchSampler, HuBERTDataSet
 from loss import hubert_loss
 from pytorch_lightning import LightningModule
 from torch import Tensor
diff --git a/examples/hubert/train.py b/examples/hubert/train.py
index 0e0eeaff04..69fab9701c 100644
--- a/examples/hubert/train.py
+++ b/examples/hubert/train.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
-"""Train the HuBERTPretrain model by using 960 hours of LibriSpeech training sets.
+"""Train the HuBERTPretrainModel by using labels generated by KMeans clustering.
 Example:
 python train.py --root-path ./exp/data/mfcc/ --feature-type mfcc --num-classes 100
 """
 
 import logging
 import pathlib
-from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
 from typing import Optional, Tuple
 
 from lightning import HuBERTPreTrainModule
@@ -90,7 +90,7 @@ def _parse_args():
         "--resume-checkpoint",
-        type=Optional[pathlib.Path],
+        type=pathlib.Path,
         default=None,
-        help="Path to the feature and label directories.",
+        help="Path to the checkpoint to resume training from. (Default: None)",
     )
     parser.add_argument(
         "--feature-type",
@@ -102,7 +102,7 @@ def _parse_args():
         "--feature-grad-mult",
         default=0.1,
         type=float,
-        help="The factor to multiply the feature extractor gradient. (Default: 0.1)",
+        help="The scaling factor to multiply the feature extractor gradient. (Default: 0.1)",
     )
     parser.add_argument(
         "--num-classes",
@@ -116,7 +116,7 @@ def _parse_args():
         default="hubert_pretrain_base",
         choices=["hubert_pretrain_base", "hubert_pretrain_large", "hubert_pretrain_xlarge"],
         type=str,
-        help="The HuBERTPretrainModel for pre-training.",
+        help="The HuBERT model to train. (Default: 'hubert_pretrain_base')",
    )
     parser.add_argument(
         "--exp-dir",
@@ -133,14 +133,15 @@ def _parse_args():
     )
     parser.add_argument(
         "--learning-rate",
-        default=0.003,
+        default=0.0005,
         type=float,
+        help="The peak learning rate. (Default: 0.0005)",
     )
     parser.add_argument(
         "--betas",
         default=(0.9, 0.98),
         type=Tuple,
-        help=" coefficients for computing running averages of gradient and its square (default: (0.9, 0.98))",
+        help="The coefficients for computing running averages of gradient and its square. (Default: (0.9, 0.98))",
     )
     parser.add_argument(
         "--eps",
@@ -157,14 +158,14 @@ def _parse_args():
     parser.add_argument(
         "--clip-norm",
         default=None,
         type=float,
-        help="The gradient norm value to clip. (Default: ``None``)",
+        help="The gradient norm value to clip. (Default: None)",
     )
     parser.add_argument(
         "--num-nodes",
-        default=1,
+        default=4,
         type=int,
-        help="Number of nodes to use for training. (Default: 1)",
+        help="Number of nodes to use for training. (Default: 4)",
     )
     parser.add_argument(
         "--gpus",

From e46304723dd2c94bab8ce6e699d571dad3d2d223 Mon Sep 17 00:00:00 2001
From: Zhaoheng Ni
Date: Fri, 20 May 2022 11:34:20 +0100
Subject: [PATCH 6/7] address comments

---
 examples/hubert/lightning.py        |  4 ++--
 examples/hubert/loss/hubert_loss.py | 18 +++++++++---------
 examples/hubert/train.py            |  4 +++-
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/examples/hubert/lightning.py b/examples/hubert/lightning.py
index ba01e9b32e..fe29ca135b 100644
--- a/examples/hubert/lightning.py
+++ b/examples/hubert/lightning.py
@@ -83,12 +83,12 @@ def _step(self, batch: Batch, batch_idx, step_type):
         if batch is None:
             return None
         waveforms, labels, audio_lengths = batch
-        logit_m, logit_u, feature_pen = self.model(
+        logit_m, logit_u, feature_penalty = self.model(
             waveforms,
             labels,
             audio_lengths,
         )
-        loss = self.loss(logit_m, logit_u, feature_pen)
+        loss = self.loss(logit_m, logit_u, feature_penalty)
         self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
         return loss

diff --git a/examples/hubert/loss/hubert_loss.py b/examples/hubert/loss/hubert_loss.py
index b8cc936674..70f98b6e9d 100644
--- a/examples/hubert/loss/hubert_loss.py
+++ b/examples/hubert/loss/hubert_loss.py
@@ -8,23 +8,23 @@ def hubert_loss(
     logit_m: Optional[Tensor],
     logit_u: Optional[Tensor],
-    feature_pen: Tensor,
+    feature_penalty: Tensor,
     masked_weight: float = 1.0,
-    nomask_weight: float = 0.0,
+    unmasked_weight: float = 0.0,
     feature_weight: float = 10.0,
     reduction: str = "sum",
 ) -> Tensor:
     """Compute the cross-entropy loss on HuBERT masked and non-masked logits.
     Args:
-        logit_m (Tensor or None): The masked logit Tensor of dimension `[masked_frames, final_dim]`.
-        logit_u (Tensor or None): The non-masked logit Tensor of dimension `[nonmasked_frames, final_dim]`.
-        feature_pen (Tensor): The feature mean value for additional penalty loss.
+        logit_m (Tensor or None): The masked logit Tensor of dimension `(masked_frames, final_dim)`.
+        logit_u (Tensor or None): The non-masked logit Tensor of dimension `(unmasked_frames, final_dim)`.
+        feature_penalty (Tensor): The feature mean value for additional penalty loss.
         masked_weight (float, optional): The weight for masked cross-entropy loss (Default: ``1.0``).
-        nomask_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``).
+        unmasked_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``).
         feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``).
         reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``).
     """
-    loss = feature_pen * feature_weight * logit_m.shape[0]
+    loss = feature_penalty * feature_weight * logit_m.shape[0]
     if logit_m is not None:
         target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
         loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction)
@@ -32,5 +32,5 @@ def hubert_loss(
     if logit_u is not None:
-        target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_m.device)
+        target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_u.device)
         loss_u = F.cross_entropy(logit_u, target_u, reduction=reduction)
-        loss += loss_u * nomask_weight
+        loss += loss_u * unmasked_weight
     return loss
diff --git a/examples/hubert/train.py b/examples/hubert/train.py
index 69fab9701c..ea62644c6e 100644
--- a/examples/hubert/train.py
+++ b/examples/hubert/train.py
@@ -18,7 +18,9 @@


 class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
-    # https://stackoverflow.com/a/18462760
+    # Combine ArgumentDefaultsHelpFormatter (to show argument defaults in the
+    # help message) with RawDescriptionHelpFormatter (to keep the custom
+    # formatting of the description and epilog). See: https://stackoverflow.com/a/18462760
     pass


From 20c2519ed31311781c04d54c1065bf248a989b59 Mon Sep 17 00:00:00 2001
From: Zhaoheng Ni
Date: Mon, 23 May 2022 07:31:36 +0100
Subject: [PATCH 7/7] fix lint

---
 examples/hubert/lightning.py | 7 ++++++-
 examples/hubert/train.py     | 8 ++++++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/examples/hubert/lightning.py b/examples/hubert/lightning.py
index fe29ca135b..c4d5b66326 100644
--- a/examples/hubert/lightning.py
+++ b/examples/hubert/lightning.py
@@ -2,7 +2,12 @@

 import torch
 import torchaudio
-from dataset import BucketizeBatchSampler, CollateFnHubert, DistributedBatchSampler, HuBERTDataSet
+from dataset import (
+    BucketizeBatchSampler,
+    CollateFnHubert,
+    DistributedBatchSampler,
+    HuBERTDataSet,
+)
 from loss import hubert_loss
 from pytorch_lightning import LightningModule
 from torch import Tensor
diff --git a/examples/hubert/train.py b/examples/hubert/train.py
index ea62644c6e..a9fefe89a7 100644
--- a/examples/hubert/train.py
+++ b/examples/hubert/train.py
@@ -6,7 +6,11 @@

 import logging
 import pathlib
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
-from typing import Optional, Tuple
+from argparse import (
+    ArgumentDefaultsHelpFormatter,
+    ArgumentParser,
+    RawDescriptionHelpFormatter,
+)
+from typing import Tuple

 from lightning import HuBERTPreTrainModule