Skip to content

Commit

Permalink
Add Python CTC decoder API (pytorch#2089)
Browse files Browse the repository at this point in the history
Summary:
Part of pytorch#2072 -- splitting up PR for easier review

This PR adds Python decoder API and basic README

Pull Request resolved: pytorch#2089

Reviewed By: mthrok

Differential Revision: D33299818

Pulled By: carolineechen

fbshipit-source-id: 778ec3692331e95258d3734f0d4ab60b6618ddbc
  • Loading branch information
Caroline Chen authored and xiaohui-zhang committed May 4, 2022
1 parent 6169a8d commit ae414cb
Show file tree
Hide file tree
Showing 11 changed files with 393 additions and 0 deletions.
20 changes: 20 additions & 0 deletions docs/source/prototype.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,26 @@ Hypothesis
.. autoclass:: Hypothesis


KenLMLexiconDecoder
~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchaudio.prototype.ctc_decoder

.. autoclass:: KenLMLexiconDecoder

.. automethod:: __call__

.. automethod:: idxs_to_tokens


kenlm_lexicon_decoder
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchaudio.prototype.ctc_decoder

.. autoclass:: kenlm_lexicon_decoder


References
~~~~~~~~~~

Expand Down
35 changes: 35 additions & 0 deletions test/torchaudio_unittest/assets/decoder/kenlm.arpa
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
\data\
ngram 1=6
ngram 2=9
ngram 3=8

\1-grams:
-0.8515802 <unk> 0
0 <s> -0.30103
-0.8515802 </s> 0
-0.8515802 foo -0.30103
-0.44013768 bar -0.30103
-0.6679358 foobar -0.30103

\2-grams:
-0.7091413 foo </s> 0
-0.6251838 bar </s> 0
-0.24384303 foobar </s> 0
-0.6251838 <s> foo -0.30103
-0.49434766 foo foo -0.30103
-0.39393726 bar foo -0.30103
-0.4582359 <s> bar -0.30103
-0.51359576 foo bar -0.30103
-0.56213206 <s> foobar -0.30103

\3-grams:
-0.45881382 bar foo </s>
-0.43354067 foo bar </s>
-0.105027884 <s> foobar </s>
-0.18033421 <s> foo foo
-0.38702002 bar foo foo
-0.15375455 <s> bar foo
-0.34500393 foo bar foo
-0.18492673 foo foo bar

\end\
3 changes: 3 additions & 0 deletions test/torchaudio_unittest/assets/decoder/lexicon.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
foo f o o |
bar b a r |
foobar f o o b a r |
7 changes: 7 additions & 0 deletions test/torchaudio_unittest/assets/decoder/tokens.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-
|
f
o
b
a
r
2 changes: 2 additions & 0 deletions test/torchaudio_unittest/common_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
TestBaseMixin,
PytorchTestCase,
TorchaudioTestCase,
skipIfNoCtcDecoder,
skipIfNoCuda,
skipIfNoExec,
skipIfNoModule,
Expand Down Expand Up @@ -42,6 +43,7 @@
"TestBaseMixin",
"PytorchTestCase",
"TorchaudioTestCase",
"skipIfNoCtcDecoder",
"skipIfNoCuda",
"skipIfNoExec",
"skipIfNoModule",
Expand Down
2 changes: 2 additions & 0 deletions test/torchaudio_unittest/common_utils/case_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torchaudio._internal.module_utils import is_module_available, is_sox_available, is_kaldi_available

from .backend_utils import set_audio_backend
from .ctc_decoder_utils import is_ctc_decoder_available


class TempDirMixin:
Expand Down Expand Up @@ -115,6 +116,7 @@ def skipIfNoCuda(test_item):

skipIfNoSox = unittest.skipIf(not is_sox_available(), reason="Sox not available")
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason="Kaldi not available")
skipIfNoCtcDecoder = unittest.skipIf(not is_ctc_decoder_available(), reason="CTC decoder not available")
skipIfRocm = unittest.skipIf(
os.getenv("TORCHAUDIO_TEST_WITH_ROCM", "0") == "1", reason="test doesn't currently work on the ROCm stack"
)
Expand Down
7 changes: 7 additions & 0 deletions test/torchaudio_unittest/common_utils/ctc_decoder_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def is_ctc_decoder_available():
try:
import torchaudio.prototype.ctc_decoder # noqa: F401

return True
except ImportError:
return False
47 changes: 47 additions & 0 deletions test/torchaudio_unittest/prototype/ctc_decoder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_asset_path,
skipIfNoCtcDecoder,
)


@skipIfNoCtcDecoder
class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
def _get_decoder(self):
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

lexicon_file = get_asset_path("decoder/lexicon.txt")
tokens_file = get_asset_path("decoder/tokens.txt")
kenlm_file = get_asset_path("decoder/kenlm.arpa")

return kenlm_lexicon_decoder(
lexicon=lexicon_file,
tokens=tokens_file,
kenlm=kenlm_file,
)

def test_construct_decoder(self):
self._get_decoder()

def test_shape(self):
B, T, N = 4, 15, 10

torch.manual_seed(0)
emissions = torch.rand(B, T, N)

decoder = self._get_decoder()
results = decoder(emissions)

self.assertEqual(len(results), B)

def test_index_to_tokens(self):
# decoder tokens: '-' '|' 'f' 'o' 'b' 'a' 'r'
decoder = self._get_decoder()

idxs = torch.LongTensor((1, 2, 1, 3, 5))
tokens = decoder.idxs_to_tokens(idxs)

expected_tokens = ["|", "f", "|", "o", "a"]
self.assertEqual(tokens, expected_tokens)
36 changes: 36 additions & 0 deletions torchaudio/csrc/decoder/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Flashlight Decoder Binding
CTC Decoder with KenLM and lexicon support based on [flashlight](https://github.com/flashlight/flashlight) decoder implementation
and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53)
Python wrapper

## Setup
### Build torchaudio with decoder support
```
BUILD_CTC_DECODER=1 python setup.py develop
```

## Usage
```py
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder
decoder = kenlm_lexicon_decoder(args...)
results = decoder(emissions) # dim (B, nbest) of dictionary of "tokens", "score", "words" keys
best_transcripts = [" ".join(results[i][0].words).strip() for i in range(B)]
```

## Required Files
- tokens: tokens for which the acoustic model generates probabilities for
- lexicon: mapping between words and its corresponding spelling
- language model: n-gram KenLM model

## Experiment Results
LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of
BASE configuration.

| Model | Decoder | dev-other | test-other | beam search params |
| ----------- | ---------- | ----------- | ---------- |-------------------------------------------- |
| BASE_10M | Greedy | 51.6 | 51 | |
| | 4-gram LM | 15.95 | 15.9 | LM weight=3.23, word score=-0.26, beam=1500 |
| BASE_100H | Greedy | 13.6 | 13.3 | |
| | 4-gram LM | 8.5 | 8.8 | LM weight=2.15, word score=-0.52, beam=50 |
| BASE_960H | Greedy | 8.9 | 8.4 | |
| | 4-gram LM | 6.3 | 6.4 | LM weight=1.74, word score=0.52, beam=50 |
16 changes: 16 additions & 0 deletions torchaudio/prototype/ctc_decoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torchaudio

try:
torchaudio._extension._load_lib("libtorchaudio_decoder")
from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder
except ImportError as err:
raise ImportError(
"flashlight decoder bindings are required to use this functionality. "
"Please set BUILD_CTC_DECODER=1 when building from source."
) from err


__all__ = [
"KenLMLexiconDecoder",
"kenlm_lexicon_decoder",
]
Loading

0 comments on commit ae414cb

Please sign in to comment.