Add HuBERT fine-tuning recipe (#2352)
Summary:
The PR contains the CTC fine-tuning recipe of HuBERT Base model.
The files include:
- lightning module
- training script
- README and the result table
- evaluation scripts

Pull Request resolved: #2352

Reviewed By: hwangjeff

Differential Revision: D36915712

Pulled By: nateanl

fbshipit-source-id: 0fbd0075ecdf7b8ef623e0a44cffd1ba78218cba
nateanl authored and facebook-github-bot committed Jun 7, 2022
1 parent 4c19e2c commit 59bd83b
Showing 9 changed files with 837 additions and 14 deletions.
50 changes: 46 additions & 4 deletions examples/hubert/README.md
@@ -1,8 +1,8 @@
# HuBERT Pre-training Example
# HuBERT Pre-training and Fine-tuning Examples

This directory contains sample implementations of the pre-training pipeline for [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447).

## Usage
## Pre-training Usage

The Base architecture of HuBERT model requires two iterations of pre-training.
### Pre-processing (1st iteration)
@@ -21,7 +21,7 @@ The first iteration is trained for 250k steps on 32 GPUs, each GPU has at most 8

Sample SLURM command for the first iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/mfcc/ --exp-dir ./exp_iter1 --feature-type mfcc --num-class 100 --max-updates 250000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --dataset-path ./exp/data/mfcc/ --exp-dir ./exp_iter1 --feature-type mfcc --num-class 100 --max-updates 250000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```

### Pre-processing (2nd iteration)
@@ -37,5 +37,47 @@ The second iteration is trained for 400k steps.

Sample SLURM command for the second iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/hubert_6/ --exp-dir ./exp_iter2 --feature-type hubert --num-class 500 --max-updates 400000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --dataset-path ./exp/data/hubert_6/ --exp-dir ./exp_iter2 --feature-type hubert --num-class 500 --max-updates 400000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```

## Fine-tuning Usage

After finishing the pre-training step, the model can be validated by fine-tuning on the `LibriLightLimited` dataset (the supervised subset of the [Libri-Light](https://github.com/facebookresearch/libri-light) dataset) with an extra feed-forward layer (the CTC layer) on top of the transformer layers.

During the whole fine-tuning process, the feature extraction layers are frozen (i.e., no gradients are backpropagated to these layers). For the first 10k fine-tuning iterations, the transformer layers are also frozen and only the CTC layer is trained. After 10k iterations, the transformer layers are fine-tuned along with the CTC layer.
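
The freezing schedule described above can be sketched as a small function. This is only an illustration, not the code of the lightning module in this commit: the function name, the group names, and the 0-based step convention are assumptions.

```python
from typing import Dict


def trainable_groups(step: int, freeze_transformer_updates: int = 10_000) -> Dict[str, bool]:
    """Return which (hypothetically named) parameter groups are updated at `step`.

    `step` is assumed to be 0-based, so steps 0..9999 are "the first 10k iterations".
    """
    return {
        # Frozen during the whole fine-tuning process.
        "feature_extractor": False,
        # Frozen at first, then fine-tuned together with the CTC layer.
        "transformer": step >= freeze_transformer_updates,
        # Trained from the very first update.
        "ctc_layer": True,
    }


if __name__ == "__main__":
    for step in (0, 9_999, 10_000, 19_999):
        print(step, trainable_groups(step))
```

In a PyTorch training loop, the returned flags could be applied by calling `requires_grad_(flag)` on the parameters of each group before the forward pass.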

Sample SLURM command for fine-tuning on the `10h` subset of the `LibriLightLimited` dataset:
```
srun --gpus-per-node=1 -N 1 --ntasks-per-node=1 --cpus-per-task=10 \
python finetune.py --dataset-path /root/datasets/ --exp-dir ./exp_finetune \
--checkpoint ./exp_iter2/checkpoints_librispeech_hubert_pretrain_base/epoch=361-step=399999.ckpt \
--gpus 1 --debug --warmup-updates 2000 --hold-updates 8000 --decay-updates 10000 --max-updates 20000 --learning-rate 5e-5
```
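
The three update counts in the command above add up to `--max-updates` (2000 + 8000 + 10000 = 20000), which suggests a three-stage learning-rate schedule: warm up, hold, then decay. The sketch below shows one common form of such a schedule. The linear warm-up, the exponential decay, and the `init_scale` and `final_scale` values are assumptions for illustration; this README does not state them.

```python
import math


def tri_stage_lr(
    step: int,
    peak_lr: float = 5e-5,
    warmup_updates: int = 2000,
    hold_updates: int = 8000,
    decay_updates: int = 10000,
    init_scale: float = 0.01,   # assumed: starting LR as a fraction of peak_lr
    final_scale: float = 0.05,  # assumed: final LR as a fraction of peak_lr
) -> float:
    """Learning rate at a 0-based `step` under a warm-up / hold / decay schedule."""
    init_lr = init_scale * peak_lr
    final_lr = final_scale * peak_lr
    if step < warmup_updates:
        # Stage 1: linear warm-up from init_lr to peak_lr.
        return init_lr + (peak_lr - init_lr) * step / warmup_updates
    step -= warmup_updates
    if step < hold_updates:
        # Stage 2: hold the peak learning rate.
        return peak_lr
    step -= hold_updates
    if step < decay_updates:
        # Stage 3: exponential decay from peak_lr towards final_lr.
        return peak_lr * math.exp(math.log(final_scale) * step / decay_updates)
    return final_lr


if __name__ == "__main__":
    for s in (0, 1000, 2000, 9999, 15000, 20000):
        print(s, tri_stage_lr(s))
```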

## Decoding

### Viterbi Decoding
The output of the CTC layer contains repeated letters, the blank symbol ("-"), and the silence symbol ("|"). Viterbi decoding merges consecutive repeated letters into a single letter, removes the blank symbols, and splits the string into a list of words at the silence symbol.
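
The three steps above can be illustrated on a plain string. This is a minimal sketch that operates on already-chosen symbols rather than on the emission tensor; the function name `collapse_ctc` is made up for this example.

```python
from itertools import groupby
from typing import List


def collapse_ctc(symbols: str, blank: str = "-", silence: str = "|") -> List[str]:
    """Collapse a per-frame CTC symbol string into a list of words."""
    # 1. Merge consecutive repeats ("hh" -> "h").
    merged = "".join(symbol for symbol, _ in groupby(symbols))
    # 2. Remove the blank symbol.
    without_blank = merged.replace(blank, "")
    # 3. Split into words at the silence symbol.
    return without_blank.replace(silence, " ").split()


if __name__ == "__main__":
    print(collapse_ctc("hh-e-ll-lo||w-oorld"))  # ['hello', 'world']
```

Note that the blank between the two `l` runs is what keeps the double letter in "hello": repeats are merged before blanks are removed.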

Sample SLURM command for evaluation with Viterbi decoding:
```
srun python evaluate.py --librispeech-path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean
```

### CTC Decoding with language model
torchaudio provides a `CTCDecoder` that is based on [Flashlight](https://github.com/flashlight/flashlight). The decoder supports KenLM language models. Use `--use-lm` to enable CTC decoding with a KenLM 4-gram language model.

Sample SLURM command for evaluation with KenLM language model:
```
srun python evaluate.py --librispeech-path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean --use-lm --beam-size 1500 --lm-weight 2.46 --word-score -0.59
```
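
To give a rough intuition for `--lm-weight` and `--word-score`, the sketch below combines an acoustic score, a language-model log-probability, and a per-word bonus into one hypothesis score, as lexicon-based beam-search decoders typically do. It is a simplification and not torchaudio's implementation: the real decoder also applies silence and unknown-word scores and prunes hypotheses during the search. The function name and the example numbers are hypothetical.

```python
def hypothesis_score(
    acoustic_score: float,
    lm_log_prob: float,
    num_words: int,
    lm_weight: float = 2.46,
    word_score: float = -0.59,
) -> float:
    """Simplified total score of one hypothesis (higher is better)."""
    return acoustic_score + lm_weight * lm_log_prob + word_score * num_words


if __name__ == "__main__":
    # Two made-up hypotheses: B fits the audio slightly better,
    # but A is far more likely under the language model.
    score_a = hypothesis_score(acoustic_score=-10.0, lm_log_prob=-8.0, num_words=2)
    score_b = hypothesis_score(acoustic_score=-9.5, lm_log_prob=-12.0, num_words=3)
    print(score_a, score_b)
```

With a larger `lm_weight` the language model has more say in the ranking; a negative `word_score` penalizes hypotheses with many words.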

### WER results
The table below contains WER results for fine-tuning the HuBERT Base model on the `10h` subset of the `LibriLightLimited` dataset.

| | WER% (Viterbi)| WER% (KenLM) |
|:-----------------:|--------------:|--------------:|
| dev-clean | 10.7 | 4.4 |
| dev-other | 18.3 | 9.7 |
| test-clean | 10.8 | 4.4 |
| test-other | 18.5 | 10.1 |
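
As in `evaluate.py` from this commit, WER is the total word-level edit distance divided by the total number of reference words, accumulated over the whole split. The pure-Python sketch below mirrors that computation; the evaluation script itself uses `torchaudio.functional.edit_distance`, and the names `word_edit_distance` and `corpus_wer` are illustrative only.

```python
from typing import Iterable, Sequence, Tuple


def word_edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two word sequences."""
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            curr.append(
                min(
                    prev[j] + 1,                           # deletion
                    curr[j - 1] + 1,                       # insertion
                    prev[j - 1] + (ref_word != hyp_word),  # substitution or match
                )
            )
        prev = curr
    return prev[-1]


def corpus_wer(pairs: Iterable[Tuple[str, str]]) -> float:
    """WER over (reference, hypothesis) transcript pairs, as a fraction."""
    total_errors, total_words = 0, 0
    for reference, hypothesis in pairs:
        total_errors += word_edit_distance(reference.split(), hypothesis.split())
        total_words += len(reference.split())
    return total_errors / total_words


if __name__ == "__main__":
    pairs = [("the cat sat", "the cat sat"), ("a quick brown fox", "a quick fox")]
    print(f"WER: {100 * corpus_wer(pairs):.1f}%")
```
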
12 changes: 11 additions & 1 deletion examples/hubert/dataset/__init__.py
@@ -1,8 +1,18 @@
from .hubert_dataset import BucketizeBatchSampler, CollateFnHubert, HuBERTDataSet
from .hubert_dataset import (
    _get_lengths_librilightlimited,
    _get_lengths_librispeech,
    BucketizeBatchSampler,
    CollateFnHubert,
    CollateFnLibriLightLimited,
    HuBERTDataSet,
)


__all__ = [
    "_get_lengths_librilightlimited",
    "_get_lengths_librispeech",
    "BucketizeBatchSampler",
    "CollateFnHubert",
    "CollateFnLibriLightLimited",
    "HuBERTDataSet",
]
71 changes: 71 additions & 0 deletions examples/hubert/dataset/hubert_dataset.py
@@ -1,4 +1,7 @@
import math
import os

import sys
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union

@@ -9,6 +12,9 @@
from torch import Tensor
from torch.utils.data import BatchSampler, Dataset, DistributedSampler

sys.path.append("..")
from utils import _get_label2id


class BucketizeBatchSampler(BatchSampler):
    """Bucketized BatchSampler for sequential data with different lengths to reduce number of paddings.
@@ -407,3 +413,68 @@ def __call__(self, batch: List[Tuple[Tensor, Tensor, int]]) -> Tuple[Tensor, Ten
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
        lengths = torch.tensor(lengths)
        return waveforms, labels, lengths


def _get_lengths_librilightlimited(files: List[Tuple[str, str]]) -> List[int]:
    # Each entry is a (directory path, file ID) pair, where the file ID has the form
    # "<speaker_id>-<chapter_id>-<utterance_id>".
    lengths = []
    for file_path, fileid in files:
        speaker_id, chapter_id, utterance_id = fileid.split("-")
        # Read the number of frames from the audio metadata
        file_audio = f"{speaker_id}-{chapter_id}-{utterance_id}.flac"
        file_audio = os.path.join(file_path, speaker_id, chapter_id, file_audio)
        length = torchaudio.info(file_audio).num_frames
        lengths.append(length)
    return lengths


def _get_lengths_librispeech(files: List[str], path: str, ext_audio: str) -> List[int]:
    lengths = []
    for file_path in files:
        speaker_id, chapter_id, utterance_id = file_path.split("-")
        fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
        file_audio = fileid_audio + ext_audio
        file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
        length = torchaudio.info(file_audio).num_frames
        lengths.append(length)
    return lengths


class CollateFnLibriLightLimited:
    """The collate class for LibriSpeech or LibriLightLimited dataset."""

    def __call__(self, batch: List[Tuple[Tensor, int, str, int, int, int]]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Args:
            batch (List(Tuple(Tensor, int, str, int, int, int))):
                The list of tuples that contains
                waveform, sample_rate, transcript, speaker_id, chapter_id, and utterance_id.
        Returns:
            (Tuple(Tensor, Tensor, Tensor, Tensor)):
                The Tensor of waveforms with dimensions `(batch, time)`.
                The Tensor of labels with dimensions `(batch, seq)`.
                The Tensor of audio lengths with dimensions `(batch,)`.
                The Tensor of label lengths with dimensions `(batch,)`.
        """
        audio_sizes = [sample[0].shape[1] for sample in batch]
        audio_size = max(audio_sizes)
        waveforms, labels, audio_lengths, label_lengths = [], [], [], []
        label2id = _get_label2id()
        for sample in batch:
            waveform, transcript = sample[0], sample[2]
            # Map each character to its label ID; spaces become the silence symbol "|".
            label = torch.tensor([label2id[e] for e in transcript.replace(" ", "|").upper()])
            audio_length = waveform.size(1)
            label_length = label.size(0)
            waveforms.append(waveform)
            audio_lengths.append(audio_length)
            label_lengths.append(label_length)
            labels.append(label)

        # Zero-pad all waveforms to the longest one in the batch.
        data = torch.zeros(len(batch), audio_size)
        for i in range(len(waveforms)):
            data[i][0 : waveforms[i].shape[1]] = waveforms[i]
        audio_lengths = torch.tensor(audio_lengths)
        label_lengths = torch.tensor(label_lengths)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-1)
        return data, labels.int(), audio_lengths.int(), label_lengths.int()
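
The label handling in `CollateFnLibriLightLimited` (upper-casing, replacing spaces with "|", mapping characters to IDs, padding with -1) can be tried out without torch. The sketch below uses a made-up vocabulary: the real mapping comes from `_get_label2id` in `utils`, which is not shown in this diff, so the symbol order here is an assumption.

```python
from typing import Dict, List, Tuple

# Hypothetical vocabulary: blank, silence, A-Z, apostrophe.
# The actual IDs are defined by `_get_label2id` and may differ.
SYMBOLS = ["-", "|"] + [chr(c) for c in range(ord("A"), ord("Z") + 1)] + ["'"]
LABEL2ID: Dict[str, int] = {symbol: i for i, symbol in enumerate(SYMBOLS)}


def encode_transcript(transcript: str, label2id: Dict[str, int] = LABEL2ID) -> List[int]:
    """Upper-case the transcript, turn spaces into "|", and map characters to IDs."""
    return [label2id[ch] for ch in transcript.replace(" ", "|").upper()]


def pad_labels(labels: List[List[int]], padding_value: int = -1) -> Tuple[List[List[int]], List[int]]:
    """Right-pad label sequences to a common length and return the original lengths."""
    lengths = [len(label) for label in labels]
    max_length = max(lengths)
    padded = [label + [padding_value] * (max_length - len(label)) for label in labels]
    return padded, lengths


if __name__ == "__main__":
    batch = [encode_transcript("hi there"), encode_transcript("ok")]
    padded, lengths = pad_labels(batch)
    print(padded)
    print(lengths)
```

The original lengths are returned alongside the padded labels so that padded positions can be skipped later, for example when computing the CTC loss.
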
181 changes: 181 additions & 0 deletions examples/hubert/evaluate.py
@@ -0,0 +1,181 @@
import argparse
import logging
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.models.decoder import ctc_decoder, CTCDecoder, download_pretrained_files
from utils import _get_id2label

logger = logging.getLogger(__name__)


def _load_checkpoint(checkpoint: str) -> torch.nn.Module:
    model = torchaudio.models.hubert_base(aux_num_out=29)
    checkpoint = torch.load(checkpoint, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    new_state_dict = {}
    for k in state_dict:
        if "model.wav2vec2" in k:
            new_state_dict[k.replace("model.wav2vec2.", "")] = state_dict[k]
        elif "aux" in k:
            new_state_dict[k] = state_dict[k]
    model.load_state_dict(new_state_dict)
    return model


def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0) -> List[str]:
    """Run greedy decoding for ctc outputs.
    Args:
        emission (torch.Tensor): Output of CTC layer. Tensor with dimensions (..., time, num_tokens).
        id2token (Dictionary): The dictionary that maps indices of emission's last dimension
            to the corresponding tokens.
        blank_idx (int, optional): The index of the blank symbol in the emission. (Default: 0)
    Returns:
        (List of str): The decoding result. List of string in lower case.
    """
    hypothesis = F.log_softmax(emission, dim=-1)
    hypothesis = hypothesis.argmax(-1).unique_consecutive()
    hypothesis = hypothesis[hypothesis != blank_idx]
    hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ")
    return hypothesis.split()


def _ctc_decode(emission, decoder: CTCDecoder) -> List[str]:
    """Run CTC decoding with a KenLM language model.
    Args:
        emission (torch.Tensor): Output of CTC layer. Tensor with dimensions (..., time, num_tokens).
        decoder (CTCDecoder): The initialized CTCDecoder.
    Returns:
        (List of str): The decoding result. List of string in lower case.
    """
    hypothesis = decoder(emission)
    hypothesis = hypothesis[0][0].words
    return hypothesis


def run_inference(args):
    # Load the fine-tuned HuBERTPretrainModel from checkpoint.
    model = _load_checkpoint(args.checkpoint)
    model.eval()

    if args.use_lm:
        # get decoder files
        files = download_pretrained_files("librispeech-4-gram")
        decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            nbest=args.nbest,
            beam_size=args.beam_size,
            beam_size_token=args.beam_size_token,
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_score,
            sil_score=args.sil_score,
            log_add=False,
        )
    else:
        id2token = _get_id2label()

    dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split)

    total_edit_distance = 0
    total_length = 0
    for idx, sample in enumerate(dataset):
        waveform, _, transcript, _, _, _ = sample
        transcript = transcript.strip().lower().replace("\n", "")

        with torch.inference_mode():
            emission, _ = model(waveform)
        if args.use_lm:
            hypothesis = _ctc_decode(emission, decoder)
        else:
            hypothesis = _viterbi_decode(emission, id2token)

        total_edit_distance += torchaudio.functional.edit_distance(transcript.split(), hypothesis)
        total_length += len(transcript.split())

        if idx % 100 == 0:
            logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
    logger.info(f"Final WER: {total_edit_distance / total_length}")


def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--librispeech-path",
        type=str,
        help="Folder where LibriSpeech dataset is stored.",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["dev-clean", "dev-other", "test-clean", "test-other"],
        help="LibriSpeech dataset split. (Default: 'test-clean')",
        default="test-clean",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="The checkpoint path of fine-tuned HuBERTPretrainModel.",
    )
    parser.add_argument("--use-lm", action="store_true", help="Whether to use language model for decoding.")
    parser.add_argument("--nbest", type=int, default=1, help="Number of best hypotheses to return.")
    parser.add_argument(
        "--beam-size",
        type=int,
        default=1500,
        help="Beam size for determining number of hypotheses to store. (Default: 1500)",
    )
    parser.add_argument(
        "--beam-size-token",
        # argparse calls `type` on the raw string, so it must be a callable such as `int`;
        # `Optional[int]` is not callable. The default stays None when the flag is omitted.
        type=int,
        default=None,
        help="Number of tokens to consider at each beam search step. (Default: None)",
    )
    parser.add_argument(
        "--beam-threshold", type=int, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
    )
    parser.add_argument(
        "--lm-weight",
        type=float,
        default=2.46,
        help="Language model weight in decoding. (Default: 2.46)",
    )
    parser.add_argument(
        "--word-score",
        type=float,
        default=-0.59,
        help="Word insertion score in decoding. (Default: -0.59)",
    )
    parser.add_argument(
        "--unk-score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
    )
    parser.add_argument("--sil-score", type=float, default=0, help="Silence insertion score. (Default: 0)")
    parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
    return parser.parse_args()


def _init_logger(debug):
    fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def _main():
    args = _parse_args()
    _init_logger(args.debug)
    run_inference(args)


if __name__ == "__main__":
    _main()
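
The key remapping performed in `_load_checkpoint` can be tried on a toy dictionary. The sketch below repeats the same filtering logic on plain Python values; the key names in the example are hypothetical and only stand in for what a real checkpoint might contain.

```python
from typing import Any, Dict


def remap_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Keep encoder and CTC-layer weights, stripping the "model.wav2vec2." prefix."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if "model.wav2vec2" in key:
            new_state_dict[key.replace("model.wav2vec2.", "")] = value
        elif "aux" in key:
            new_state_dict[key] = value
        # Any other entry is dropped.
    return new_state_dict


if __name__ == "__main__":
    # Hypothetical checkpoint keys; values stand in for tensors.
    toy = {
        "model.wav2vec2.encoder.weight": 1,
        "aux.weight": 2,
        "some_other_module.weight": 3,
    }
    print(remap_state_dict(toy))  # {'encoder.weight': 1, 'aux.weight': 2}
```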