diff --git a/torchaudio/csrc/decoder/README.md b/torchaudio/csrc/decoder/README.md index 21cffaf24d..c8277c85c2 100644 --- a/torchaudio/csrc/decoder/README.md +++ b/torchaudio/csrc/decoder/README.md @@ -1,11 +1,11 @@ # Flashlight Decoder Binding CTC Decoder with KenLM and lexicon support based on [flashlight](https://github.com/flashlight/flashlight) decoder implementation -and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53) +and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53) Python wrapper ## Setup ### Build KenLM -- Install KenLM following the instructions [here](https://github.com/kpu/kenlm#compiling) +- Install KenLM in your audio directory following the instructions [here](https://github.com/kpu/kenlm#compiling) - set `KENLM_ROOT` variable to the KenLM installation path ### Build torchaudio with decoder support ``` @@ -17,7 +17,7 @@ BUILD_CTC_DECODER=1 python setup.py develop from torchaudio.prototype import kenlm_lexicon_decoder decoder = kenlm_lexicon_decoder(args...) results = decoder(emissions) # dim (B, nbest) of dictionary of "tokens", "score", "words" keys -best_transcript = " ".join(results[0][0]["words"]).strip() +best_transcripts = [" ".join(results[i][0].words).strip() for i in range(B)] ``` ## Required Files @@ -26,11 +26,11 @@ best_transcript = " ".join(results[0][0]["words"]).strip() - language model: n-gram KenLM model ## Experiment Results -LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of -BASE configuration. +LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of +BASE configuration. 
| Model | Decoder | dev-other | test-other | beam search params | -| ----------- | ---------- | ----------- | ---------- | ------------------------------------------- | +| ----------- | ---------- | ----------- | ---------- |-------------------------------------------- | | BASE_10M | Greedy | 51.6 | 51 | | | | 4-gram LM | 15.95 | 15.9 | LM weight=3.23, word score=-0.26, beam=1500 | | BASE_100H | Greedy | 13.6 | 13.3 | | diff --git a/torchaudio/csrc/decoder/src/decoder/lm/KenLM.cpp b/torchaudio/csrc/decoder/src/decoder/lm/KenLM.cpp index d0c9794a97..6d7bc52834 100644 --- a/torchaudio/csrc/decoder/src/decoder/lm/KenLM.cpp +++ b/torchaudio/csrc/decoder/src/decoder/lm/KenLM.cpp @@ -9,7 +9,7 @@ #include -#include +#include "kenlm/lm/model.hh" namespace torchaudio { namespace lib { diff --git a/torchaudio/prototype/ctc_decoder/__init__.py b/torchaudio/prototype/ctc_decoder/__init__.py new file mode 100644 index 0000000000..b56ffb3909 --- /dev/null +++ b/torchaudio/prototype/ctc_decoder/__init__.py @@ -0,0 +1,6 @@ +from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder + +__all__ = [ + "KenLMLexiconDecoder", + "kenlm_lexicon_decoder", +] diff --git a/torchaudio/prototype/ctc_decoder.py b/torchaudio/prototype/ctc_decoder/ctc_decoder.py similarity index 89% rename from torchaudio/prototype/ctc_decoder.py rename to torchaudio/prototype/ctc_decoder/ctc_decoder.py index a73b64fd6b..e34064c591 100644 --- a/torchaudio/prototype/ctc_decoder.py +++ b/torchaudio/prototype/ctc_decoder/ctc_decoder.py @@ -1,6 +1,7 @@ import torch import itertools as it from typing import List, Optional, Dict +from collections import namedtuple import torchaudio @@ -26,6 +27,8 @@ __all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"] +Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"]) + class KenLMLexiconDecoder: def __init__( self, @@ -38,9 +41,11 @@ def __init__( blank_token: str, sil_token: str, ) -> None: - """ - Construct a KenLM CTC Lexcion Decoder. 
+ KenLM CTC Decoder with Lexicon constraint. + + Note: + To build the decoder, please use the factory function kenlm_lexicon_decoder. Args: nbest (int): number of best decodings to return @@ -107,13 +112,13 @@ def decode( in time axis of the output Tensor in each batch Returns: - List[List[Dict]]: + List[List[Hypothesis]]: List of sorted best hypotheses for each audio sequence in the batch. - Each hypothesis is dictionary with the following mapping: - "tokens": torch.LongTensor of raw token IDs - "score": hypothesis score - "words": list of decoded words + Each hypothesis is a named tuple with the following fields: + tokens: torch.LongTensor of raw token IDs + score: hypothesis score + words: list of decoded words """ B, T, N = emissions.size() if lengths is None: @@ -128,13 +133,11 @@ def decode( nbest_results = results[: self.nbest] hypos.append( [ - { - "tokens": self._get_tokens(result.tokens), - "score": result.score, - "words": [ - self.word_dict.get_entry(x) for x in result.words if x >= 0 - ] - } + Hypothesis( + self._get_tokens(result.tokens), # token ids + list(self.word_dict.get_entry(x) for x in result.words if x >= 0), # words + result.score, # score + ) for result in nbest_results ] )