Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Dec 14, 2021
1 parent 894e80c commit 9981f77
Show file tree
Hide file tree
Showing 30 changed files with 118 additions and 123 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ endif()
option(BUILD_SOX "Build libsox statically" ON)
option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_RNNT "Enable RNN transducer" ON)
option(BUILD_FL_DECODER "Build Flashlight decoder" OFF)
option(BUILD_CTC_DECODER "Build Flashlight decoder" OFF)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF)
option(USE_ROCM "Enable ROCM support" OFF)
Expand Down
2 changes: 1 addition & 1 deletion examples/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio")

SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio")
SET(BUILD_RNNT ON CACHE BOOL "Build RNN transducer into libtorchaudio")
SET(BUILD_FL_DECODER OFF CACHE BOOL "Build Flashlight decoder into libtorchaudio")
SET(BUILD_CTC_DECODER OFF CACHE BOOL "Build Flashlight decoder into libtorchaudio")
SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding")

find_package(Torch REQUIRED)
Expand Down
2 changes: 1 addition & 1 deletion examples/libtorchaudio/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ cmake -GNinja \
-DBUILD_SOX=ON \
-DBUILD_KALDI=OFF \
-DBUILD_RNNT=ON \
-DBUILD_FL_DECODER=OFF \
-DBUILD_CTC_DECODER=OFF \
..
cmake --build .
```
Expand Down
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ endif()
################################################################################
# KenLM
################################################################################
if (BUILD_FL_DECODER)
if (BUILD_CTC_DECODER)
find_package(kenlm)
if (NOT kenlm_FOUND)
message(FATAL_ERROR "KenLM not found - Please install KenLM and set KENLM_ROOT.")
Expand Down
6 changes: 3 additions & 3 deletions tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_build(var, default=False):
_BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX", True)
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_RNNT = _get_build("BUILD_RNNT", True)
_BUILD_FL_DECODER = _get_build("BUILD_FL_DECODER", False)
_BUILD_CTC_DECODER = _get_build("BUILD_CTC_DECODER", False)
_USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None)
_USE_OPENMP = _get_build("USE_OPENMP", True) and \
Expand All @@ -51,7 +51,7 @@ def get_ext_modules():
Extension(name='torchaudio._torchaudio', sources=[]),
]

if _BUILD_FL_DECODER:
if _BUILD_CTC_DECODER:
modules.extend([
Extension(name='torchaudio.lib.libtorchaudio_decoder', sources=[]),
Extension(name='torchaudio._torchaudio_decoder', sources=[]),
Expand Down Expand Up @@ -97,7 +97,7 @@ def build_extension(self, ext):
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
f"-DBUILD_RNNT:BOOL={'ON' if _BUILD_RNNT else 'OFF'}",
f"-DBUILD_FL_DECODER:BOOL={'ON' if _BUILD_FL_DECODER else 'OFF'}",
f"-DBUILD_CTC_DECODER:BOOL={'ON' if _BUILD_CTC_DECODER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
Expand Down
6 changes: 3 additions & 3 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ define_library(
################################################################################
# libtorchaudio_decoder.so
################################################################################
if (BUILD_FL_DECODER)
if (BUILD_CTC_DECODER)
set(
LIBTORCHAUDIO_DECODER_SOURCES
decoder/src/decoder/LexiconDecoder.cpp
Expand All @@ -144,7 +144,7 @@ if (BUILD_FL_DECODER)
)
set(
LIBTORCHAUDIO_DECODER_DEFINITIONS
BUILD_FL_DECODER
BUILD_CTC_DECODER
KENLM_MAX_ORDER=${KENLM_MAX_ORDER}
)
set(
Expand Down Expand Up @@ -226,7 +226,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
libtorchaudio
"${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
)
if(BUILD_FL_DECODER)
if(BUILD_CTC_DECODER)
set(
DECODER_EXTENSION_SOURCES
decoder/bindings/pybind.cpp
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/decoder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Python wrapper
- set `KENLM_ROOT` variable to the KenLM installation path
### Build torchaudio with decoder support
```
BUILD_FL_DECODER=1 python setup.py develop
BUILD_CTC_DECODER=1 python setup.py develop
```

## Usage
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/bindings/_decoder.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include <pybind11/pybind11.h>
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/bindings/_dictionary.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include <pybind11/pybind11.h>
Expand Down
31 changes: 14 additions & 17 deletions torchaudio/csrc/decoder/bindings/pybind.cpp
Original file line number Diff line number Diff line change
@@ -1,55 +1,53 @@
#include <torch/extension.h>

// FLASHLIGHT
#include <torchaudio/csrc/decoder/bindings/_decoder.cpp>
#include <torchaudio/csrc/decoder/bindings/_dictionary.cpp>

PYBIND11_MODULE(_torchaudio_decoder, m) {
// FLASHLIGHT DECODER
#ifdef BUILD_FL_DECODER
py::enum_<SmearingMode>(m, "SmearingMode")
#ifdef BUILD_CTC_DECODER
py::enum_<SmearingMode>(m, "_SmearingMode")
.value("NONE", SmearingMode::NONE)
.value("MAX", SmearingMode::MAX)
.value("LOGADD", SmearingMode::LOGADD);

py::class_<TrieNode, TrieNodePtr>(m, "TrieNode")
py::class_<TrieNode, TrieNodePtr>(m, "_TrieNode")
.def(py::init<int>(), "idx"_a)
.def_readwrite("children", &TrieNode::children)
.def_readwrite("idx", &TrieNode::idx)
.def_readwrite("labels", &TrieNode::labels)
.def_readwrite("scores", &TrieNode::scores)
.def_readwrite("max_score", &TrieNode::maxScore);

py::class_<Trie, TriePtr>(m, "Trie")
py::class_<Trie, TriePtr>(m, "_Trie")
.def(py::init<int, int>(), "max_children"_a, "root_idx"_a)
.def("get_root", &Trie::getRoot)
.def("insert", &Trie::insert, "indices"_a, "label"_a, "score"_a)
.def("search", &Trie::search, "indices"_a)
.def("smear", &Trie::smear, "smear_mode"_a);

py::class_<LM, LMPtr, PyLM>(m, "LM")
py::class_<LM, LMPtr, PyLM>(m, "_LM")
.def(py::init<>())
.def("start", &LM::start, "start_with_nothing"_a)
.def("score", &LM::score, "state"_a, "usr_token_idx"_a)
.def("finish", &LM::finish, "state"_a);

py::class_<LMState, LMStatePtr>(m, "LMState")
py::class_<LMState, LMStatePtr>(m, "_LMState")
.def(py::init<>())
.def_readwrite("children", &LMState::children)
.def("compare", &LMState::compare, "state"_a)
.def("child", &LMState::child<LMState>, "usr_index"_a);

py::class_<KenLM, KenLMPtr, LM>(m, "KenLM")
py::class_<KenLM, KenLMPtr, LM>(m, "_KenLM")
.def(
py::init<const std::string&, const Dictionary&>(),
"path"_a,
"usr_token_dict"_a);

py::enum_<CriterionType>(m, "CriterionType")
py::enum_<CriterionType>(m, "_CriterionType")
.value("ASG", CriterionType::ASG)
.value("CTC", CriterionType::CTC);

py::class_<LexiconDecoderOptions>(m, "LexiconDecoderOptions")
py::class_<LexiconDecoderOptions>(m, "_LexiconDecoderOptions")
.def(
py::init<
const int,
Expand Down Expand Up @@ -80,7 +78,7 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
.def_readwrite("log_add", &LexiconDecoderOptions::logAdd)
.def_readwrite("criterion_type", &LexiconDecoderOptions::criterionType);

py::class_<DecodeResult>(m, "DecodeResult")
py::class_<DecodeResult>(m, "_DecodeResult")
.def(py::init<int>(), "length"_a)
.def_readwrite("score", &DecodeResult::score)
.def_readwrite("amScore", &DecodeResult::amScore)
Expand All @@ -89,7 +87,7 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
.def_readwrite("tokens", &DecodeResult::tokens);

// NB: `decode` and `decodeStep` expect raw emissions pointers.
py::class_<LexiconDecoder>(m, "LexiconDecoder")
py::class_<LexiconDecoder>(m, "_LexiconDecoder")
.def(py::init<
LexiconDecoderOptions,
const TriePtr,
Expand All @@ -116,8 +114,7 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
.def("get_all_final_hypothesis", &LexiconDecoder::getAllFinalHypothesis);


// FLASHLIGHT DICTIONARY
py::class_<Dictionary>(m, "Dictionary")
py::class_<Dictionary>(m, "_Dictionary")
.def(py::init<>())
.def(py::init<const std::string&>(), "filename"_a)
.def("entry_size", &Dictionary::entrySize)
Expand All @@ -137,7 +134,7 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
"map_indices_to_entries",
&Dictionary::mapIndicesToEntries,
"indices"_a);
m.def("create_word_dict", &createWordDict, "lexicon"_a);
m.def("load_words", &loadWords, "filename"_a, "max_words"_a = -1);
m.def("_create_word_dict", &createWordDict, "lexicon"_a);
m.def("_load_words", &loadWords, "filename"_a, "max_words"_a = -1);
#endif
}
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/Decoder.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/LexiconDecoder.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include <stdlib.h>
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/Trie.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include <math.h>
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/Trie.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/Utils.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

namespace torchaudio {
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/Utils.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/lm/KenLM.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/lm/KenLM.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/decoder/lm/LM.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/dictionary/Defines.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/dictionary/Dictionary.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include <iostream>
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/dictionary/Dictionary.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/dictionary/String.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include "torchaudio/csrc/decoder/src/dictionary/String.h"
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/dictionary/String.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/decoder/src/dictionary/System.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include "torchaudio/csrc/decoder/src/dictionary/System.h"
Expand Down
Loading

0 comments on commit 9981f77

Please sign in to comment.