forked from pytorch/audio
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Python CTC decoder API (pytorch#2089)
Summary: Part of pytorch#2072 -- splitting up PR for easier review This PR adds Python decoder API and basic README Pull Request resolved: pytorch#2089 Reviewed By: mthrok Differential Revision: D33299818 Pulled By: carolineechen fbshipit-source-id: 778ec3692331e95258d3734f0d4ab60b6618ddbc
- Loading branch information
1 parent
6169a8d
commit ae414cb
Showing
11 changed files
with
393 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
\data\ | ||
ngram 1=6 | ||
ngram 2=9 | ||
ngram 3=8 | ||
|
||
\1-grams: | ||
-0.8515802 <unk> 0 | ||
0 <s> -0.30103 | ||
-0.8515802 </s> 0 | ||
-0.8515802 foo -0.30103 | ||
-0.44013768 bar -0.30103 | ||
-0.6679358 foobar -0.30103 | ||
|
||
\2-grams: | ||
-0.7091413 foo </s> 0 | ||
-0.6251838 bar </s> 0 | ||
-0.24384303 foobar </s> 0 | ||
-0.6251838 <s> foo -0.30103 | ||
-0.49434766 foo foo -0.30103 | ||
-0.39393726 bar foo -0.30103 | ||
-0.4582359 <s> bar -0.30103 | ||
-0.51359576 foo bar -0.30103 | ||
-0.56213206 <s> foobar -0.30103 | ||
|
||
\3-grams: | ||
-0.45881382 bar foo </s> | ||
-0.43354067 foo bar </s> | ||
-0.105027884 <s> foobar </s> | ||
-0.18033421 <s> foo foo | ||
-0.38702002 bar foo foo | ||
-0.15375455 <s> bar foo | ||
-0.34500393 foo bar foo | ||
-0.18492673 foo foo bar | ||
|
||
\end\ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
foo f o o | | ||
bar b a r | | ||
foobar f o o b a r | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
- | ||
| | ||
f | ||
o | ||
b | ||
a | ||
r |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
def is_ctc_decoder_available():
    """Report whether the prototype CTC decoder module can be imported.

    Returns:
        bool: ``True`` when ``torchaudio.prototype.ctc_decoder`` imports
        without error, ``False`` when the import raises ``ImportError``.
        Any other exception raised during the import is propagated.
    """
    try:
        # The import is executed purely as an availability probe; the
        # imported name itself is unused (hence the F401 suppression).
        import torchaudio.prototype.ctc_decoder  # noqa: F401
    except ImportError:
        return False
    return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
from torchaudio_unittest.common_utils import ( | ||
TempDirMixin, | ||
TorchaudioTestCase, | ||
get_asset_path, | ||
skipIfNoCtcDecoder, | ||
) | ||
|
||
|
||
@skipIfNoCtcDecoder
class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
    """Basic tests for the prototype KenLM lexicon CTC decoder."""

    def _get_decoder(self):
        """Build a decoder from the ``decoder/`` test assets.

        Returns:
            The object returned by ``kenlm_lexicon_decoder`` when given the
            lexicon, tokens and KenLM asset paths.
        """
        # Imported inside the method rather than at module level --
        # presumably so this file can still be imported (and the class
        # skipped) when the decoder extension is not built.
        from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

        lexicon_file = get_asset_path("decoder/lexicon.txt")
        tokens_file = get_asset_path("decoder/tokens.txt")
        kenlm_file = get_asset_path("decoder/kenlm.arpa")

        return kenlm_lexicon_decoder(
            lexicon=lexicon_file,
            tokens=tokens_file,
            kenlm=kenlm_file,
        )

    def test_construct_decoder(self):
        """Constructing the decoder from the test assets must not raise."""
        self._get_decoder()

    def test_shape(self):
        """Decoding a batch of B emissions yields B result entries."""
        B, T, N = 4, 15, 10

        # Fixed seed so the random emissions are reproducible across runs.
        torch.manual_seed(0)
        emissions = torch.rand(B, T, N)

        decoder = self._get_decoder()
        results = decoder(emissions)

        self.assertEqual(len(results), B)

    def test_index_to_tokens(self):
        """``idxs_to_tokens`` maps indices to the matching token strings."""
        # decoder tokens: '-' '|' 'f' 'o' 'b' 'a' 'r'
        decoder = self._get_decoder()

        idxs = torch.LongTensor((1, 2, 1, 3, 5))
        tokens = decoder.idxs_to_tokens(idxs)

        expected_tokens = ["|", "f", "|", "o", "a"]
        self.assertEqual(tokens, expected_tokens)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Flashlight Decoder Binding
CTC decoder with KenLM and lexicon support, based on the [flashlight](https://github.com/flashlight/flashlight) decoder implementation
and the fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53)
Python wrapper.
|
||
## Setup
### Build torchaudio with decoder support
```
BUILD_CTC_DECODER=1 python setup.py develop
```
|
||
## Usage
```py
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder
decoder = kenlm_lexicon_decoder(args...)
results = decoder(emissions) # dim (B, nbest) of dictionary of "tokens", "score", "words" keys
best_transcripts = [" ".join(results[i][0].words).strip() for i in range(B)]
```
|
||
## Required Files
- tokens: the tokens for which the acoustic model generates probabilities
- lexicon: a mapping between words and their corresponding spellings
- language model: an n-gram KenLM model
|
||
## Experiment Results
LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of
BASE configuration.

| Model     | Decoder   | dev-other | test-other | beam search params                          |
| --------- | --------- | --------- | ---------- | ------------------------------------------- |
| BASE_10M  | Greedy    | 51.6      | 51         |                                             |
|           | 4-gram LM | 15.95     | 15.9       | LM weight=3.23, word score=-0.26, beam=1500 |
| BASE_100H | Greedy    | 13.6      | 13.3       |                                             |
|           | 4-gram LM | 8.5       | 8.8        | LM weight=2.15, word score=-0.52, beam=50   |
| BASE_960H | Greedy    | 8.9       | 8.4        |                                             |
|           | 4-gram LM | 6.3       | 6.4        | LM weight=1.74, word score=0.52, beam=50    |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import torchaudio

try:
    # Load the native decoder library first, then import the Python wrappers
    # (presumably the wrappers depend on the library being loaded).
    torchaudio._extension._load_lib("libtorchaudio_decoder")
    from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder
except ImportError as err:
    # Re-raise with build instructions while keeping the original error
    # chained as the cause.
    raise ImportError(
        "flashlight decoder bindings are required to use this functionality. "
        "Please set BUILD_CTC_DECODER=1 when building from source."
    ) from err


# Names exported by this package.
__all__ = [
    "KenLMLexiconDecoder",
    "kenlm_lexicon_decoder",
]
Oops, something went wrong.