diff --git a/docs/source/prototype.rst b/docs/source/prototype.rst
index c67b5c01301..e2635540c74 100644
--- a/docs/source/prototype.rst
+++ b/docs/source/prototype.rst
@@ -65,6 +65,18 @@ Hypothesis
 
 .. autoclass:: Hypothesis
 
+KenLMLexiconDecoder
+~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: KenLMLexiconDecoder
+
+
+kenlm_lexicon_decoder
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: kenlm_lexicon_decoder
+
+
 References
 ~~~~~~~~~~
 
diff --git a/torchaudio/csrc/decoder/README.md b/torchaudio/csrc/decoder/README.md
index 21cffaf24d2..c8277c85c20 100644
--- a/torchaudio/csrc/decoder/README.md
+++ b/torchaudio/csrc/decoder/README.md
@@ -1,11 +1,11 @@
 # Flashlight Decoder Binding
 CTC Decoder with KenLM and lexicon support based on [flashlight](https://github.com/flashlight/flashlight) decoder implementation
-and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53)
+and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53) Python wrapper
 
 ## Setup
 ### Build KenLM
-- Install KenLM following the instructions [here](https://github.com/kpu/kenlm#compiling)
+- Install KenLM in your `audio` (torchaudio repository) directory following the instructions [here](https://github.com/kpu/kenlm#compiling)
 - set `KENLM_ROOT` variable to the KenLM installation path
 ### Build torchaudio with decoder support
 ```
@@ -17,7 +17,7 @@ BUILD_CTC_DECODER=1 python setup.py develop
 from torchaudio.prototype import kenlm_lexicon_decoder
 decoder = kenlm_lexicon_decoder(args...)
-results = decoder(emissions) # dim (B, nbest) of dictionary of "tokens", "score", "words" keys
-best_transcript = " ".join(results[0][0]["words"]).strip()
+results = decoder(emissions) # list of shape (B, nbest) of Hypothesis named tuples with "tokens", "words", "score" fields
+best_transcripts = [" ".join(results[i][0].words).strip() for i in range(B)]
 ```
 
 ## Required Files
@@ -26,11 +26,11 @@ best_transcript = " ".join(results[0][0]["words"]).strip()
 - tokens: tokens for which the acoustic model generates probabilities for
 - lexicon: mapping between words and its corresponding spelling
 - language model: n-gram KenLM model
 
 ## Experiment Results
 LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of
 BASE configuration.
 | Model       | Decoder    | dev-other   | test-other | beam search params                          |
-| ----------- | ---------- | ----------- | ---------- | ------------------------------------------- |
+| ----------- | ---------- | ----------- | ---------- |-------------------------------------------- |
 | BASE_10M    | Greedy     | 51.6        | 51         |                                             |
 |             | 4-gram LM  | 15.95       | 15.9       | LM weight=3.23, word score=-0.26, beam=1500 |
 | BASE_100H   | Greedy     | 13.6        | 13.3       |                                             |
diff --git a/torchaudio/prototype/__init__.py b/torchaudio/prototype/__init__.py
index 0b90a9cd6d0..b0fc4be667a 100644
--- a/torchaudio/prototype/__init__.py
+++ b/torchaudio/prototype/__init__.py
@@ -1,6 +1,7 @@
 from .emformer import Emformer
 from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
 from .rnnt_decoder import Hypothesis, RNNTBeamSearch
+from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder
 
 
 __all__ = [
@@ -10,4 +11,6 @@
     "RNNTBeamSearch",
     "emformer_rnnt_base",
     "emformer_rnnt_model",
+    "KenLMLexiconDecoder",
+    "kenlm_lexicon_decoder",
 ]
diff --git a/torchaudio/prototype/ctc_decoder.py b/torchaudio/prototype/ctc_decoder.py
index a73b64fd6b4..a39bd64aa6a 100644
--- a/torchaudio/prototype/ctc_decoder.py
+++ b/torchaudio/prototype/ctc_decoder.py
@@ -1,6 +1,7 @@
 import torch
 import itertools as it
 from typing import List, Optional, Dict
+from collections import namedtuple
 
 import torchaudio
 
@@ -26,6 +27,8 @@
 __all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"]
 
 
+Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"])
+
 class KenLMLexiconDecoder:
     def __init__(
@@ -38,9 +41,12 @@ def __init__(
         blank_token: str,
         sil_token: str,
     ) -> None:
-        """
-        Construct a KenLM CTC Lexcion Decoder.
+        """
+        KenLM CTC Decoder with Lexicon constraint.
+
+        Note:
+            To build the decoder, please use the factory function kenlm_lexicon_decoder.
 
         Args:
             nbest (int): number of best decodings to return
@@ -107,13 +112,13 @@ def decode(
                 in time axis of the output Tensor in each batch
 
         Returns:
-            List[List[Dict]]:
+            List[List[Hypothesis]]:
                 List of sorted best hypotheses for each audio sequence in the batch.
-                Each hypothesis is dictionary with the following mapping:
-                    "tokens": torch.LongTensor of raw token IDs
-                    "score": hypothesis score
-                    "words": list of decoded words
+                Each hypothesis is a named tuple with the following fields:
+                    tokens: torch.LongTensor of raw token IDs
+                    score: hypothesis score
+                    words: list of decoded words
 
         """
         B, T, N = emissions.size()
         if lengths is None:
@@ -128,13 +133,11 @@
             nbest_results = results[: self.nbest]
             hypos.append(
                 [
-                    {
-                        "tokens": self._get_tokens(result.tokens),
-                        "score": result.score,
-                        "words": [
-                            self.word_dict.get_entry(x) for x in result.words if x >= 0
-                        ]
-                    }
+                    Hypothesis(
+                        tokens=self._get_tokens(result.tokens),
+                        words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
+                        score=result.score,
+                    )
                     for result in nbest_results
                 ]
             )
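
For reviewers who want to exercise the new API end to end, here is a minimal sketch. The file paths (`lexicon.txt`, `tokens.txt`, `4gram.bin`) and the keyword names `lexicon`, `tokens`, and `kenlm` are assumptions for illustration only; of the arguments shown, `nbest`, `blank_token`, and `sil_token` appear in this diff, and the `lm_weight`/`word_score` values echo the README results table, so check the `kenlm_lexicon_decoder` docstring for the actual signature.

```python
import torch
from torchaudio.prototype import kenlm_lexicon_decoder

# NOTE: keyword names and file paths below are assumptions for illustration;
# verify them against the kenlm_lexicon_decoder docstring.
decoder = kenlm_lexicon_decoder(
    lexicon="lexicon.txt",  # word -> token-spelling mapping (hypothetical path)
    tokens="tokens.txt",    # token set of the acoustic model (hypothetical path)
    kenlm="4gram.bin",      # compiled KenLM n-gram binary (hypothetical path)
    nbest=3,                # number of best hypotheses to return per utterance
    lm_weight=3.23,         # beam-search values quoted in the README table
    word_score=-0.26,
)

B, T, N = 4, 100, 32                             # batch size, frames, token-set size
emissions = torch.randn(B, T, N)                 # CTC emissions, e.g. Wav2Vec2 output
lengths = torch.full((B,), T, dtype=torch.long)  # valid frames per batch entry

# decode() returns List[List[Hypothesis]] of shape (B, nbest), sorted best-first;
# each Hypothesis carries tokens (LongTensor), words (List[str]), and score.
results = decoder.decode(emissions, lengths)
for hyps in results:
    best = hyps[0]
    print(f"{' '.join(best.words).strip()} (score: {best.score})")
```

Because `decode` accepts per-utterance lengths, padded batches decode correctly; per the `if lengths is None:` branch in the diff, omitting `lengths` treats every sequence as spanning all `T` frames.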