Commit: add evaluation script, add readme

nateanl committed May 16, 2022
1 parent cd7b04a commit b06caee
Showing 5 changed files with 181 additions and 76 deletions.
43 changes: 43 additions & 0 deletions examples/hubert/README.md
@@ -0,0 +1,43 @@
# HuBERT Fine-tuning Example

## Usage

After finishing the pre-training step, the model can be validated by fine-tuning it on a supervised subset of the Libri-Light dataset, with an extra feed-forward layer added on top of the transformer layers.

During the whole fine-tuning process, the feature extraction layers are frozen (i.e., no gradients are back-propagated to these layers). For the first 10k fine-tuning iterations, the transformer layers are also frozen and only the CTC layer is trained. After 10k iterations, the transformer layers are fine-tuned along with the CTC layer.
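A minimal sketch of this freezing schedule (assuming the `feature_extractor`, `encoder`, and `aux` attribute names of `torchaudio.models.hubert_base`; the actual logic lives in `finetune.py`):

```
# Sketch only -- not the finetune.py implementation. Assumes the fine-tuned
# model is torchaudio.models.hubert_base(aux_num_out=29), whose submodules are
# `feature_extractor` (conv layers), `encoder` (transformer), and `aux` (CTC layer).
import torchaudio

model = torchaudio.models.hubert_base(aux_num_out=29)

# The feature extraction layers stay frozen for the whole fine-tuning run.
for p in model.feature_extractor.parameters():
    p.requires_grad = False

def freeze_transformer(step: int, freeze_steps: int = 10_000) -> None:
    # For the first `freeze_steps` iterations only the CTC layer (`aux`) trains;
    # afterwards the transformer layers are fine-tuned along with it.
    for p in model.encoder.parameters():
        p.requires_grad = step >= freeze_steps
```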

Sample SLURM command for fine-tuning on the `10h` subset of the `LibriLightLimited` dataset:
```
srun --gpus-per-node=1 -N 1 --ntasks-per-node=1 --cpus-per-task=10 \
python finetune.py --root-path /root/datasets/ --exp-dir ./exp_finetune \
--checkpoint /exp_iter2/checkpoints_librispeech_hubert_pretrain_base/epoch=361-step=399999.ckpt \
--gpus 1 --debug --warmup-updates 2000 --hold-updates 8000 --decay-updates 10000 --max-updates 20000 --learning-rate 5e-5
```
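
The `--warmup-updates`, `--hold-updates`, and `--decay-updates` values add up to `--max-updates`, which suggests a tri-stage learning-rate schedule (linear warm-up, constant hold, then decay). A hedged sketch of that interpretation, since `finetune.py` itself is not shown in this commit view:

```
# Hedged sketch of a tri-stage LR schedule matching the flags above; the real
# schedule is defined in finetune.py, which is not shown in this commit view.
# Linear decay is assumed here for simplicity.
def tri_stage_lr(step, base_lr=5e-5, warmup=2000, hold=8000, decay=10000):
    if step < warmup:                       # stage 1: linear warm-up
        return base_lr * step / warmup
    if step < warmup + hold:                # stage 2: hold at the peak rate
        return base_lr
    t = min(step - warmup - hold, decay)    # stage 3: decay to zero
    return base_lr * (1.0 - t / decay)
```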

## Decoding

### Viterbi Decoding
The output of the CTC layer contains repeated letters, the blank symbol ("-"), and the silence symbol ("|"). Viterbi decoding collapses each run of repeated letters into a single letter, removes the blank symbol, and splits the resulting string into a list of words at the silence symbol.
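For example, the frame-level string `hh-ee-ll-llo|` collapses to `hello`. A toy illustration of the rule (`collapse` is a hypothetical helper; the real implementation is `_viterbi_decode` in `evaluate.py` below):

```
# Toy illustration of the Viterbi (greedy) collapse rule; not part of evaluate.py.
from itertools import groupby

def collapse(frame_tokens):
    merged = [t for t, _ in groupby(frame_tokens)]      # merge repeated letters
    no_blank = [t for t in merged if t != "-"]          # drop the blank symbol
    return "".join(no_blank).replace("|", " ").split()  # split words on silence

print(collapse(list("hh-ee-ll-llo|")))  # ['hello']
```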

Sample SLURM command for evaluation with Viterbi decoding:
```
srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean
```

### CTC Decoding with language model
torchaudio provides a CTC decoder (`CTCDecoder`) based on Flashlight. The decoder supports KenLM language models. Pass `--use-lm` to enable CTC decoding with the KenLM 4-gram language model.
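A minimal construction sketch, mirroring the decoder setup in `evaluate.py` below (the pretrained files come from `download_pretrained_files`):

```
# Mirrors evaluate.py: build the Flashlight-based CTC decoder with the
# pretrained LibriSpeech 4-gram KenLM files.
from torchaudio.prototype.ctc_decoder import ctc_decoder, download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")
decoder = ctc_decoder(
    lexicon=files.lexicon,  # word -> token-spelling lexicon
    tokens=files.tokens,    # token set of the acoustic model
    lm=files.lm,            # KenLM 4-gram language model
    beam_size=1500,
    lm_weight=2.46,
    word_score=-0.59,
)
```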

Sample SLURM command for evaluation with the KenLM language model:
```
srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean --use-lm --beam-size 1500 --lm-weight 2.46 --word-score -0.59
```

### WER results
The table below contains WER results for fine-tuning the HuBERT Base model on the `10h` subset of the `LibriLightLimited` dataset.

| | WER% (Viterbi)| WER% (KenLM) |
|:-----------------:|--------------:|--------------:|
| dev-clean | 10.7 | 4.4 |
| dev-other | 18.3 | 9.7 |
| test-clean | 10.8 | 4.4 |
| test-other | 18.5 | 10.1 |
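
In this table, WER is the total word-level edit distance between hypotheses and reference transcripts divided by the total number of reference words in the split, i.e. `total_edit_distance / total_length` as accumulated in `evaluate.py`.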
125 changes: 86 additions & 39 deletions examples/hubert/evaluate.py
@@ -1,16 +1,17 @@
import argparse
import logging
from typing import Optional
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
import torchaudio
-from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files
+from torchaudio.prototype.ctc_decoder import CTCDecoder, ctc_decoder, download_pretrained_files
from utils import _get_id2label

logger = logging.getLogger(__name__)


-def _load_checkpoint(checkpoint: str):
+def _load_checkpoint(checkpoint: str) -> torch.nn.Module:
model = torchaudio.models.hubert_base(aux_num_out=29)
checkpoint = torch.load(checkpoint, map_location="cpu")
state_dict = checkpoint["state_dict"]
@@ -24,42 +25,80 @@ def _load_checkpoint(checkpoint: str):
return model


def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0) -> List[str]:
    """Run greedy (Viterbi) decoding on CTC outputs.

    Args:
        emission (torch.Tensor): Output of the CTC layer. Tensor with dimensions (..., time, num_tokens).
        id2token (Dict): Dictionary that maps indices of the emission's last dimension
            to the corresponding tokens.
        blank_idx (int, optional): Index of the CTC blank token. (Default: 0)

    Returns:
        (List[str]): The decoded words, in lower case.
    """
    hypothesis = F.log_softmax(emission, dim=-1)
    # Pick the most likely token per frame, then collapse runs of repeated tokens.
    hypothesis = hypothesis.argmax(-1).unique_consecutive()
    # Drop the blank token, map ids to characters, and split words on the silence token.
    hypothesis = hypothesis[hypothesis != blank_idx]
    hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ")
    return hypothesis.split()


def _ctc_decode(emission: torch.Tensor, decoder: CTCDecoder) -> List[str]:
    """Run beam-search CTC decoding with a KenLM language model.

    Args:
        emission (torch.Tensor): Output of the CTC layer. Tensor with dimensions (..., time, num_tokens).
        decoder (CTCDecoder): The initialized CTC decoder.

    Returns:
        (List[str]): The decoded words, in lower case.
    """
    hypothesis = decoder(emission)
    # Take the words of the best hypothesis for the first (and only) utterance.
    hypothesis = hypothesis[0][0].words
    return hypothesis


def run_inference(args):
# Load the fine-tuned HuBERTPretrainModel from checkpoint.
model = _load_checkpoint(args.checkpoint)
model.eval()

-    # get decoder files
-    files = download_pretrained_files("librispeech-4-gram")
-
-    decoder = lexicon_decoder(
-        lexicon=files.lexicon,
-        tokens=files.tokens,
-        lm=files.lm,
-        nbest=args.nbest,
-        beam_size=args.beam_size,
-        beam_size_token=args.beam_size_token,
-        beam_threshold=args.beam_threshold,
-        lm_weight=args.lm_weight,
-        word_score=args.word_score,
-        unk_score=args.unk_score,
-        sil_score=args.sil_score,
-        log_add=False,
-    )
-
-    dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split, download=False)
+    if args.use_lm:
+        # get decoder files
+        files = download_pretrained_files("librispeech-4-gram")
+        decoder = ctc_decoder(
+            lexicon=files.lexicon,
+            tokens=files.tokens,
+            lm=files.lm,
+            nbest=args.nbest,
+            beam_size=args.beam_size,
+            beam_size_token=args.beam_size_token,
+            beam_threshold=args.beam_threshold,
+            lm_weight=args.lm_weight,
+            word_score=args.word_score,
+            unk_score=args.unk_score,
+            sil_score=args.sil_score,
+            log_add=False,
+        )
+    else:
+        id2token = _get_id2label()
+
+    dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split)

total_edit_distance = 0
total_length = 0
for idx, sample in enumerate(dataset):
waveform, _, transcript, _, _, _ = sample
-        transcript = transcript.strip().lower().strip()
+        transcript = transcript.strip().lower().strip().replace("\n", "")

        with torch.inference_mode():
            emission, _ = model(waveform)
-            results = decoder(emission)
+            if args.use_lm:
+                hypothesis = _ctc_decode(emission, decoder)
+            else:
+                hypothesis = _viterbi_decode(emission, id2token)

-        total_edit_distance += torchaudio.functional.edit_distance(transcript.split(), results[0][0].words)
+        total_edit_distance += torchaudio.functional.edit_distance(transcript.split(), hypothesis)
total_length += len(transcript.split())

if idx % 100 == 0:
@@ -75,7 +114,7 @@ def _parse_args():
parser.add_argument(
"--librispeech_path",
type=str,
help="Folder where LibriSpeech is stored",
help="Folder where LibriSpeech dataset is stored.",
)
parser.add_argument(
"--split",
@@ -87,34 +126,42 @@ def _parse_args():
parser.add_argument(
"--checkpoint",
type=str,
help="The checkpoint of fine-tuned HuBERTPretrainModel",
help="The checkpoint path of fine-tuned HuBERTPretrainModel.",
)
parser.add_argument("--nbest", type=int, default=1, help="number of best hypotheses to return")
parser.add_argument("--use-lm", action="store_true", help="Whether to use language model for decoding.")
parser.add_argument("--nbest", type=int, default=1, help="Number of best hypotheses to return.")
parser.add_argument(
"--beam-size", type=int, default=500, help="beam size for determining number of hypotheses to store"
"--beam-size",
type=int,
default=1500,
help="Beam size for determining number of hypotheses to store. (Default: 1500)",
)
parser.add_argument(
"--beam-size-token",
        type=int,  # argparse `type` must be a callable; `Optional[int]` is not
default=None,
help="number of tokens to consider at each beam search step",
help="Number of tokens to consider at each beam search step. (Default: None)",
)
-    parser.add_argument("--beam-threshold", type=int, default=100, help="beam threshold for pruning hypotheses")
+    parser.add_argument(
+        "--beam-threshold", type=int, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
+    )
parser.add_argument(
"--lm-weight",
type=float,
-        default=2,
-        help="languge model weight",
+        default=2.46,
+        help="Language model weight in decoding. (Default: 2.46)",
)
parser.add_argument(
"--word-score",
type=float,
-        default=-1,
-        help="word insertion score",
+        default=-0.59,
+        help="Word insertion score in decoding. (Default: -0.59)",
)
-    parser.add_argument("--unk_score", type=float, default=float("-inf"), help="unknown word insertion score")
-    parser.add_argument("--sil_score", type=float, default=0, help="silence insertion score")
-    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
+    parser.add_argument(
+        "--unk_score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
+    )
+    parser.add_argument("--sil_score", type=float, default=0, help="Silence insertion score. (Default: 0)")
+    parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
return parser.parse_args()

