Commit 0f17a0a

author
Caroline Chen
committed
api/docs modifications
1 parent cf32cb4 commit 0f17a0a

4 files changed, +39 -20 lines changed

docs/source/prototype.rst

Lines changed: 12 additions & 0 deletions
@@ -65,6 +65,18 @@ Hypothesis
6565
.. autoclass:: Hypothesis
6666

6767

68+
KenLMLexiconDecoder
69+
~~~~~~~~~~~~~~~~~~~
70+
71+
.. autoclass:: KenLMLexiconDecoder
72+
73+
74+
kenlm_lexicon_decoder
75+
~~~~~~~~~~~~~~~~~~~~~
76+
77+
.. autofunction:: kenlm_lexicon_decoder
78+
79+
6880
References
6981
~~~~~~~~~~
7082

torchaudio/csrc/decoder/README.md

Lines changed: 6 additions & 6 deletions
@@ -1,11 +1,11 @@
11
# Flashlight Decoder Binding
22
CTC Decoder with KenLM and lexicon support based on [flashlight](https://github.com/flashlight/flashlight) decoder implementation
3-
and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53)
3+
and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53)
44
Python wrapper
55

66
## Setup
77
### Build KenLM
8-
- Install KenLM following the instructions [here](https://github.com/kpu/kenlm#compiling)
8+
- Install KenLM in your `audio` (torchaudio repository) directory, following the instructions [here](https://github.com/kpu/kenlm#compiling)
99
- Set the `KENLM_ROOT` environment variable to the KenLM installation path
1010
### Build torchaudio with decoder support
1111
```
@@ -17,7 +17,7 @@ BUILD_CTC_DECODER=1 python setup.py develop
1717
from torchaudio.prototype import kenlm_lexicon_decoder
1818
decoder = kenlm_lexicon_decoder(args...)
1919
results = decoder(emissions) # List[List[Hypothesis]] of dim (B, nbest); each Hypothesis has tokens, words, score fields
20-
best_transcript = " ".join(results[0][0]["words"]).strip()
20+
best_transcripts = [" ".join(hypos[0].words).strip() for hypos in results]
2121
```
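The decoder itself needs a KenLM model and lexicon to run, so the following is only a minimal sketch of the post-processing step, assuming the output is a list (one entry per utterance) of best-first lists of `Hypothesis` named tuples. The `Hypothesis` stand-in, the `best_transcripts` helper, and the `fake_results` data are all hypothetical and exist only for illustration.

```python
from collections import namedtuple

# Stand-in for the named tuple defined in ctc_decoder.py (same field order).
Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"])


def best_transcripts(results):
    """Return the top-ranked transcript for each utterance in the batch.

    `results` is assumed to be a list of per-utterance lists of Hypothesis
    objects, sorted best-first. An utterance with no hypotheses maps to "".
    """
    return [" ".join(hypos[0].words).strip() if hypos else "" for hypos in results]


# Hypothetical decoder output for a batch of two utterances with nbest=2.
# Token IDs are plain lists here; the real decoder returns torch.LongTensor.
fake_results = [
    [
        Hypothesis([3, 7, 7, 1], ["hello", "world"], -12.5),
        Hypothesis([3, 7, 1], ["hello", "word"], -14.0),
    ],
    [
        Hypothesis([5, 2], ["good", "morning"], -8.1),
        Hypothesis([5], ["good"], -9.7),
    ],
]

transcripts = best_transcripts(fake_results)
print(transcripts)  # ['hello world', 'good morning']
```

Iterating over `results` directly avoids depending on a separate batch-size variable.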
2222

2323
## Required Files
@@ -26,11 +26,11 @@ best_transcript = " ".join(results[0][0]["words"]).strip()
2626
- language model: n-gram KenLM model
2727

2828
## Experiment Results
29-
LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of
30-
BASE configuration.
29+
LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of
30+
BASE configuration.
3131

3232
| Model | Decoder | dev-other | test-other | beam search params |
33-
| ----------- | ---------- | ----------- | ---------- | ------------------------------------------- |
33+
| ----------- | ---------- | ----------- | ---------- |-------------------------------------------- |
3434
| BASE_10M | Greedy | 51.6 | 51 | |
3535
| | 4-gram LM | 15.95 | 15.9 | LM weight=3.23, word score=-0.26, beam=1500 |
3636
| BASE_100H | Greedy | 13.6 | 13.3 | |

torchaudio/prototype/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .emformer import Emformer
22
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
33
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
4+
from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder
45

56

67
__all__ = [
@@ -10,4 +11,6 @@
1011
"RNNTBeamSearch",
1112
"emformer_rnnt_base",
1213
"emformer_rnnt_model",
14+
"KenLMLexiconDecoder",
15+
"kenlm_lexicon_decoder",
1316
]
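
Every name listed in `__all__` has to be bound in the module, otherwise `from package import *` raises `AttributeError`. One way to keep the list and the imports in sync is a small consistency check, sketched below. The `missing_exports` helper and the `demo` module are hypothetical and are not part of torchaudio.

```python
import types


def missing_exports(module):
    """Return names listed in module.__all__ that the module does not define."""
    return [
        name
        for name in getattr(module, "__all__", [])
        if not hasattr(module, name)
    ]


# Hypothetical stand-in for a package __init__ that exports two names
# but only binds one of them.
demo = types.ModuleType("demo_prototype")
demo.kenlm_lexicon_decoder = lambda *args, **kwargs: None
demo.__all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"]

missing = missing_exports(demo)
print(missing)  # ['KenLMLexiconDecoder']
```

A check like this could run in a unit test, with an empty list as the expected result.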

torchaudio/prototype/ctc_decoder.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import itertools as it
33
from typing import List, Optional, Dict
4+
from collections import namedtuple
45

56
import torchaudio
67

@@ -26,6 +27,8 @@
2627
__all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"]
2728

2829

30+
Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"])
31+
2932
class KenLMLexiconDecoder:
3033
def __init__(
3134
self,
@@ -38,9 +41,11 @@ def __init__(
3841
blank_token: str,
3942
sil_token: str,
4043
) -> None:
41-
4244
"""
43-
Construct a KenLM CTC Lexcion Decoder.
45+
KenLM CTC Decoder with Lexicon constraint.
46+
47+
Note:
48+
To build the decoder, please use the factory function :py:func:`kenlm_lexicon_decoder`.
4449
4550
Args:
4651
nbest (int): number of best decodings to return
@@ -107,13 +112,13 @@ def decode(
107112
in time axis of the output Tensor in each batch
108113
109114
Returns:
110-
List[List[Dict]]:
115+
List[List[Hypothesis]]:
111116
List of sorted best hypotheses for each audio sequence in the batch.
112117
113-
Each hypothesis is dictionary with the following mapping:
114-
"tokens": torch.LongTensor of raw token IDs
115-
"score": hypothesis score
116-
"words": list of decoded words
118+
Each hypothesis is a named tuple with the following fields:
119+
tokens: torch.LongTensor of raw token IDs
120+
score: hypothesis score
121+
words: list of decoded words
117122
"""
118123
B, T, N = emissions.size()
119124
if lengths is None:
@@ -128,13 +133,12 @@ def decode(
128133
nbest_results = results[: self.nbest]
129134
hypos.append(
130135
[
131-
{
132-
"tokens": self._get_tokens(result.tokens),
133-
"score": result.score,
134-
"words": [
135-
self.word_dict.get_entry(x) for x in result.words if x >= 0
136-
]
137-
}
136+
Hypothesis(
137+
self._get_tokens(result.tokens),
138+
[self.word_dict.get_entry(x) for x in result.words if x >= 0],
139+
result.score,
140+
141+
)
138142
for result in nbest_results
139143
]
140144
)
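
The move from dictionaries to a named tuple changes how callers read a hypothesis. The sketch below shows field access, positional unpacking in the declared order (`tokens`, `words`, `score`), and `_asdict()` as a bridge for code that still expects key lookups. The `vocab` mapping and `raw_word_ids` list are hypothetical stand-ins for the decoder's word dictionary and raw output, and the negative-index filter mirrors the `x >= 0` condition in the diff.

```python
from collections import namedtuple

# Same definition as the one added to ctc_decoder.py in this commit.
Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"])

# Hypothetical word dictionary and raw word indices; negative indices are
# skipped, mirroring the `x >= 0` filter in the decode loop.
vocab = {0: "the", 1: "cat", 2: "sat"}
raw_word_ids = [0, -1, 1, -1, 2]

hypo = Hypothesis(
    tokens=[4, 9, 9, 2],  # plain list here; the decoder uses torch.LongTensor
    words=[vocab[i] for i in raw_word_ids if i >= 0],
    score=-3.2,
)

# Attribute access replaces the former result["words"] lookups.
print(hypo.words)

# Positional unpacking follows the declared field order.
tokens, words, score = hypo

# _asdict() gives a mapping for callers that still index by key.
legacy = hypo._asdict()
print(legacy["score"])
```

Because named tuples are immutable, code that used to assign into the result dictionary would need `hypo._replace(...)` instead.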
