Add recipe for HuBERT model pre-training #2198

Closed · wants to merge 7 commits
examples/hubert/README.md (new file, 41 additions)
# HuBERT Pre-training Example

This directory contains a sample implementation of the pre-training pipeline for [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447).

## Usage

The Base architecture of the HuBERT model requires two iterations of pre-training.
### Pre-processing (1st iteration)
[`preprocess.py`](./preprocess.py) generates the file lists of training and validation data, trains a KMeans clustering model on either MFCC features or a transformer layer's output from a pre-trained HuBERT model, and then predicts the cluster IDs for each utterance, which serve as the labels for masked-prediction training.

Sample SLURM command for the first iteration of pre-processing, which uses MFCC features to train the KMeans model:
```
srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type mfcc --exp-dir ./exp --num-cluster 100
```
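
For illustration, the sketch below shows the kind of MFCC-plus-KMeans labeling this step performs. It is a minimal, hypothetical example: the file paths, `n_mfcc`, and the use of `torchaudio.transforms.MFCC` together with scikit-learn's `MiniBatchKMeans` are assumptions, and the actual `preprocess.py` may differ.
```
# Hypothetical sketch of the MFCC + KMeans labeling step (not the actual preprocess.py).
import torch
import torchaudio
from sklearn.cluster import MiniBatchKMeans

mfcc = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=13)


def utterance_features(path: str) -> torch.Tensor:
    """Return per-frame MFCC features of shape (frames, n_mfcc)."""
    waveform, sample_rate = torchaudio.load(path)
    return mfcc(waveform).squeeze(0).transpose(0, 1)


# Fit the clustering model on features pooled over the training file list.
train_paths = ["audio_0001.flac", "audio_0002.flac"]  # placeholder file list
pooled = torch.cat([utterance_features(p) for p in train_paths]).numpy()
kmeans = MiniBatchKMeans(n_clusters=100, batch_size=10000).fit(pooled)

# The per-frame cluster IDs become the pseudo-labels for masked prediction.
labels = kmeans.predict(utterance_features(train_paths[0]).numpy())
```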

### Pre-training (1st iteration)

[`train.py`](./train.py) trains a HuBERTPretrainModel using PyTorch Lightning. Note that the script requires access to GPU nodes for training.

The first iteration is trained for 250k steps on 32 GPUs, with each GPU processing at most 87.5 seconds of audio per mini-batch.

Sample SLURM command for the first iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/mfcc/ --exp-dir ./exp_iter1 --feature-type mfcc --num-class 100 --max-updates 250000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```
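
For reference, the 87.5-second per-GPU budget corresponds to the `max_token_count` passed to `BucketizeBatchSampler` in `lightning.py` (seconds per batch multiplied by the 16 kHz sample rate):
```
seconds_per_batch = 87.5
sample_rate = 16000
max_token_count = int(seconds_per_batch * sample_rate)  # 1,400,000 audio samples per GPU mini-batch
```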

### Pre-processing (2nd iteration)
After the first iteration of pre-training, the output of an intermediate transformer layer of the pre-trained HuBERTPretrainModel is used to train a new KMeans clustering model. That KMeans model then generates the cluster labels for the second iteration of masked-prediction training.

Sample SLURM command for the second iteration of pre-processing. The 6th transformer layer's output is used as the input feature for training the KMeans model. Note that the number of clusters is increased to 500 to improve performance.
```
srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type hubert --exp-dir ./exp --layer-index 6 --checkpoint-path ./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt --num-cluster 500
```
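
To make the `--feat-type hubert` path concrete, here is a hedged sketch of loading the iteration-1 checkpoint and extracting the 6th transformer layer's output via `Wav2Vec2Model.extract_features`. The checkpoint-key handling, the `wav2vec2` attribute access, and the placeholder audio path are assumptions; the actual `preprocess.py` may differ.
```
# Hypothetical sketch: extract 6th-layer features from the iteration-1 checkpoint.
import torch
import torchaudio

model = torchaudio.models.hubert_pretrain_base(num_classes=100)
ckpt = torch.load(
    "./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt", map_location="cpu"
)
# Lightning checkpoints store the module's parameters under "state_dict" with a "model." prefix.
state_dict = {k[len("model."):]: v for k, v in ckpt["state_dict"].items() if k.startswith("model.")}
model.load_state_dict(state_dict)
encoder = model.wav2vec2.eval()  # the underlying Wav2Vec2Model

waveform, sample_rate = torchaudio.load("sample.flac")  # placeholder 16 kHz utterance
with torch.no_grad():
    # extract_features returns one Tensor per transformer layer, up to num_layers.
    hidden_states, _ = encoder.extract_features(waveform, num_layers=6)
features = hidden_states[-1]  # output of the 6th transformer layer: input to the new KMeans model
```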

### Pre-training (2nd iteration)
The second iteration is trained for 400k steps.

Sample SLURM command for the second iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/hubert_6/ --exp-dir ./exp_iter2 --feature-type hubert --num-class 500 --max-updates 400000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```
examples/hubert/lightning.py (new file, 153 additions)
from typing import Tuple

import torch
import torchaudio
from dataset import (
BucketizeBatchSampler,
CollateFnHubert,
DistributedBatchSampler,
HuBERTDataSet,
)
from loss import hubert_loss
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader


Batch = Tuple[Tensor, Tensor, Tensor]


class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler):
"""Linear learning rate scheduler with warm up."""

def __init__(
self,
optimizer: Optimizer,
warmup_updates: int,
max_updates: int,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.max_updates = max_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

def get_lr(self):
if self._step_count <= self.warmup_updates:
return [self._step_count / self.warmup_updates * base_lr for base_lr in self.base_lrs]
elif self._step_count >= self.max_updates:
return [0.0 for _ in self.base_lrs]
else:
pct_remaining = (self.max_updates - self._step_count) / (self.max_updates - self.warmup_updates)
return [base_lr * pct_remaining for base_lr in self.base_lrs]


class HuBERTPreTrainModule(LightningModule):
def __init__(
self,
*,
model_name: str,
feature_grad_mult: float,
num_classes: int,
dataset: str,
root_path: str,
feature_type: str,
seconds_per_batch: float,
learning_rate: float,
betas: Tuple[float, float],
eps: float,
weight_decay: float,
warmup_updates: int,
max_updates: int,
):
super().__init__()

if model_name == "hubert_pretrain_base":
self.model = torchaudio.models.hubert_pretrain_base(
feature_grad_mult=feature_grad_mult, num_classes=num_classes
)
elif model_name == "hubert_pretrain_large":
self.model = torchaudio.models.hubert_pretrain_large()
elif model_name == "hubert_pretrain_xlarge":
self.model = torchaudio.models.hubert_pretrain_xlarge()
else:
raise ValueError(f"Unsupported model name: {model_name}")

self.loss = hubert_loss
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
)
self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
self.dataset = dataset
self.root_path = root_path
self.feature_type = feature_type
self.seconds_per_batch = seconds_per_batch

def _step(self, batch: Batch, batch_idx, step_type):
if batch is None:
return None
waveforms, labels, audio_lengths = batch
logit_m, logit_u, feature_penalty = self.model(
waveforms,
labels,
audio_lengths,
)
loss = self.loss(logit_m, logit_u, feature_penalty)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss

def configure_optimizers(self):
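        # Return the optimizer together with its scheduler config; "interval": "step"
        # tells Lightning to step the LR scheduler after every optimizer update
        # rather than once per epoch.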
return (
[self.optimizer],
[
{
"scheduler": self.lr_scheduler,
"interval": "step",
},
],
)

def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")

def validation_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "val")

def train_dataloader(self):
dataset = HuBERTDataSet(self.root_path, self.dataset, "train")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=10000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=True,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.current_epoch)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
)
return dataloader

def val_dataloader(self):
dataset = HuBERTDataSet(self.root_path, self.dataset, "valid")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=False,
)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
)
return dataloader
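
`train.py` itself is not part of this excerpt. As a rough, hypothetical sketch of how the module above could be driven by a PyTorch Lightning `Trainer` (values such as `feature_grad_mult`, `betas`, `eps`, `weight_decay`, and `warmup_updates` are assumptions, not taken from this PR):
```
# Hypothetical driver for HuBERTPreTrainModule; the PR's train.py may differ.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from lightning import HuBERTPreTrainModule

module = HuBERTPreTrainModule(
    model_name="hubert_pretrain_base",
    feature_grad_mult=0.1,       # assumed value
    num_classes=100,             # matches --num-class in the iteration-1 command
    dataset="librispeech",       # assumed dataset identifier
    root_path="./exp/data/mfcc/",
    feature_type="mfcc",
    seconds_per_batch=87.5,      # 87.5 seconds of audio per GPU, as in the README
    learning_rate=0.0005,
    betas=(0.9, 0.98),           # assumed Adam betas
    eps=1e-6,                    # assumed
    weight_decay=0.01,           # assumed
    warmup_updates=32000,        # assumed warm-up length
    max_updates=250000,
)
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    num_nodes=4,
    strategy="ddp",
    max_steps=250000,
    callbacks=[ModelCheckpoint(dirpath="./exp_iter1/checkpoints", save_top_k=-1)],
)
trainer.fit(module)
```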
examples/hubert/loss/__init__.py (new file, 5 additions)
from .hubert_loss import hubert_loss

__all__ = [
"hubert_loss",
]
examples/hubert/loss/hubert_loss.py (new file, 36 additions)
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def hubert_loss(
logit_m: Optional[Tensor],
logit_u: Optional[Tensor],
feature_penalty: Tensor,
masked_weight: float = 1.0,
unmasked_weight: float = 0.0,
feature_weight: float = 10.0,
reduction: str = "sum",
) -> Tensor:
"""Compute the cross-entropy loss on HuBERT masked and non-masked logits.
Args:
logit_m (Tensor or None): The masked logit Tensor of dimension `(masked_frames, final_dim)`.
logit_u (Tensor or None): The non-masked logit Tensor of dimension `(unmasked_frames, final_dim)`.
feature_penalty (Tensor): The feature mean value for additional penalty loss.
masked_weight (float, optional): The weight for masked cross-entropy loss (Default: ``1.0``).
unmasked_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``).
feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``).
reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``).
"""
    # Scale the feature penalty by the number of frames so it is comparable to the summed cross-entropy.
    num_frames = logit_m.shape[0] if logit_m is not None else logit_u.shape[0]
    loss = feature_penalty * feature_weight * num_frames
if logit_m is not None:
target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction)
loss += loss_m * masked_weight
if logit_u is not None:
        target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_u.device)
loss_u = F.cross_entropy(logit_u, target_u, reduction=reduction)
loss += loss_u * unmasked_weight
return loss
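
A quick, illustrative use of the loss with random tensors (the shapes below are made up for the example):
```
# Illustrative smoke test of hubert_loss with random logits.
import torch

from loss import hubert_loss

logit_m = torch.randn(8, 256)        # 8 masked frames, assumed final_dim of 256
logit_u = torch.randn(24, 256)       # 24 unmasked frames
feature_penalty = torch.tensor(0.5)  # feature-penalty value reported by the model

# Cross-entropy over the masked frames (target index 0) plus the scaled feature penalty.
loss = hubert_loss(logit_m, logit_u, feature_penalty)
print(loss)
```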