Use ctcdecode in native client (Fixes #1668) #1679

Merged
merged 14 commits into from
Oct 30, 2018
Changes from 11 commits
1 change: 1 addition & 0 deletions .gitattributes
@@ -1,3 +1,4 @@
*.binary filter=lfs diff=lfs merge=lfs -crlf
data/lm/trie filter=lfs diff=lfs merge=lfs -crlf
data/lm/vocab.txt filter=lfs diff=lfs merge=lfs -text
+ data/lm/trie.ctcdecode filter=lfs diff=lfs merge=lfs -text
3 changes: 3 additions & 0 deletions DeepSpeech.py
@@ -1770,6 +1770,9 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
n_steps=n_steps,
previous_state=previous_state)

+ # Apply softmax for CTC decoder
+ logits = tf.nn.softmax(logits)

new_state_c, new_state_h = layers['rnn_output_state']

# Initial zero state
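
Note on the hunk above: the added `tf.nn.softmax` turns the raw logits of each timestep into a probability distribution before they reach the CTC decoder. As a rough illustration of that operation only (a sketch, not code from this PR; `SoftmaxSketch` is a hypothetical name and the real work is done by TensorFlow's Softmax kernel), a numerically stable softmax over one timestep could look like this in C++:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of the PR: numerically stable softmax over
// the logits of a single timestep.
std::vector<double> SoftmaxSketch(const std::vector<double>& logits) {
  assert(!logits.empty());
  // Subtract the maximum so that exp() cannot overflow for large logits.
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<double> probs(logits.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - max_logit);
    sum += probs[i];
  }
  // Normalize so that the entries sum to 1.
  for (double& p : probs) {
    p /= sum;
  }
  return probs;
}
```

The output keeps the ordering of the logits and sums to 1, which is the form a probability-based beam search would consume.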
5 changes: 2 additions & 3 deletions data/lm/README.md
@@ -38,9 +38,8 @@ binary_path = '/tmp/lm.binary'
os.remove(lm_path)
```

- The trie was then generated from the list of unique words in the corpus (data/lm/vocab.txt):
+ The trie was then generated from the vocabulary of the language model:

```bash
- tr -s '[[:space:]]' '\n' < /tmp/lower.txt | sort -u > /tmp/vocab.txt
- ./generate_trie ../data/alphabet.txt /tmp/lm.binary /tmp/vocab.txt /tmp/trie
+ ./generate_trie ../data/alphabet.txt /tmp/lm.binary /tmp/trie
```
3 changes: 3 additions & 0 deletions data/lm/trie.ctcdecode
Git LFS file not shown
Binary file added data/smoke_test/vocab.trie.ctcdecode
Binary file not shown.
86 changes: 54 additions & 32 deletions native_client/BUILD
@@ -10,13 +10,32 @@ genrule(
tools = [":ds_git_version.sh"]
)

+ KENLM_SOURCES = glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
+ "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
+ exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]) + glob(["boost_locale/**/*.hpp"])

+ KENLM_INCLUDES = [
+ "kenlm",
+ "boost_locale"
+ ]

+ DECODER_SOURCES = glob([
+ "ctcdecode/*.h",
+ "ctcdecode/*.cpp",
+ "ctcdecode/third_party/openfst-1.6.7/src/lib/*.cc"
+ ]) + KENLM_SOURCES

+ DECODER_INCLUDES = [
+ ".",
+ "ctcdecode/third_party/openfst-1.6.7/src/include",
+ "ctcdecode/third_party/ThreadPool"
+ ] + KENLM_INCLUDES

tf_cc_shared_object(
name = "libdeepspeech.so",
srcs = ["deepspeech.cc",
"deepspeech.h",
"alphabet.h",
- "beam_search.h",
- "trie_node.h",
"c_speech_features/c_speech_features.cpp",
"kiss_fft130/kiss_fft.c",
"kiss_fft130/tools/kiss_fftr.c",
@@ -26,19 +45,34 @@ tf_cc_shared_object(
"kiss_fft130/_kiss_fft_guts.h",
"kiss_fft130/tools/kiss_fftr.h",
"ds_version.h"] +
- glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
- "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
- exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]) +
- glob(["boost_locale/**/*.hpp"]),
+ DECODER_SOURCES,
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
# which makes it harder to see our own warnings
copts = ["-Wno-sign-compare", "-fvisibility=hidden"],
linkopts = select({
"//tensorflow:darwin": [],
- "//conditions:default": [
- "-Wl,-Bsymbolic",
- "-Wl,-Bsymbolic-functions",
- "-Wl,-export-dynamic",
+ "//tensorflow:linux_x86_64": [
+ "-ldl",
+ "-pthread",
+ "-Wl,-Bsymbolic",
+ "-Wl,-Bsymbolic-functions",
+ "-Wl,-export-dynamic",
+ ],
+ "//tensorflow:rpi3": [
+ "-ldl",
+ "-pthread",
+ "-Wl,-Bsymbolic",
+ "-Wl,-Bsymbolic-functions",
+ "-Wl,-export-dynamic",
+ "-l:libstdc++.a",
+ ],
+ "//tensorflow:rpi3-armv8": [
+ "-ldl",
+ "-pthread",
+ "-Wl,-Bsymbolic",
+ "-Wl,-Bsymbolic-functions",
+ "-Wl,-export-dynamic",
+ "-l:libstdc++.a",
Collaborator:
This still lacks some trivial factorization?

Contributor Author:
Where?

Collaborator:
Sorry, I did the review commit by commit, and this is fixed by a newer one, so disregard that comment (GitHub flagged it as "outdated").

],
}),
deps = [
@@ -54,6 +88,7 @@ tf_cc_shared_object(
"//tensorflow/core/kernels:constant_op", # Const
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
"//tensorflow/core/kernels:identity_op", # Identity
+ "//tensorflow/core/kernels:softmax_op", # Softmax
"//tensorflow/core/kernels:transpose_op", # Transpose
"//tensorflow/core/kernels:reshape_op", # Reshape
"//tensorflow/core/kernels:shape_ops", # Shape
@@ -76,7 +111,7 @@ tf_cc_shared_object(
] + if_cuda([
"//tensorflow/core:core",
]),
- includes = ["kenlm", "boost_locale", "c_speech_features", "kiss_fft130"],
+ includes = ["c_speech_features", "kiss_fft130"] + DECODER_INCLUDES,
defines = ["KENLM_MAX_ORDER=6"],
)

@@ -88,11 +123,8 @@ tf_cc_shared_object(
"alphabet.h",
"trie_node.h"
] +
- glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
- "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
- exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]) +
- glob(["boost_locale/**/*.hpp"]),
- includes = ["kenlm", "boost_locale"],
+ KENLM_SOURCES,
+ includes = KENLM_INCLUDES,
copts = ["-std=c++11"],
defines = ["KENLM_MAX_ORDER=6"],
deps = ["//tensorflow/core:framework_headers_lib",
@@ -105,32 +137,22 @@ cc_binary(
name = "generate_trie",
srcs = [
"generate_trie.cpp",
- "trie_node.h",
"alphabet.h",
- ] +
- glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
- "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
- exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]) +
- glob(["boost_locale/**/*.hpp"]),
- includes = ["kenlm", "boost_locale"],
+ ] + DECODER_SOURCES,
+ includes = DECODER_INCLUDES,
copts = ["-std=c++11"],
- linkopts = ["-lm"],
+ linkopts = ["-lm", "-ldl", "-pthread"],
defines = ["KENLM_MAX_ORDER=6"],
)

cc_binary(
name = "trie_load",
srcs = [
"trie_load.cc",
- "trie_node.h",
"alphabet.h",
- ] +
- glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
- "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
- exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]) +
- glob(["boost_locale/**/*.hpp"]),
- includes = ["kenlm", "boost_locale"],
+ ] + DECODER_SOURCES,
+ includes = DECODER_INCLUDES,
copts = ["-std=c++11"],
- linkopts = ["-lm"],
+ linkopts = ["-lm", "-ldl", "-pthread"],
defines = ["KENLM_MAX_ORDER=6"],
)
39 changes: 26 additions & 13 deletions native_client/alphabet.h
@@ -6,6 +6,7 @@
#include <iostream>
#include <string>
#include <unordered_map>
+ #include <vector>

/*
* Loads a text file describing a mapping of labels to strings, one string per
@@ -17,29 +18,29 @@ class Alphabet {
Alphabet(const char *config_file) {
std::ifstream in(config_file, std::ios::in);
unsigned int label = 0;
+ space_label_ = -2;
Contributor:
This creates a strange dependency between Alphabet and ctc_beam_search_decoder: both know about this magic number.

Contributor Author:
I don't see where ctc_beam_search_decoder knows about this magic number? It's supposed to not match any real label when the alphabet has no space label. That said, I now realize there is a subtle bug here: since space_label_ is unsigned, a missing space label is actually stored as UINT_MAX-1, which could break for an alphabet that's exactly UINT_MAX long :)
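
Note: the wrap-around described in this comment can be checked in isolation. The following is a minimal sketch, not code from the PR; `MissingSpaceSentinel` and `SentinelCollides` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <limits>

// Hypothetical helper, not code from the PR: mimics the assignment
// `space_label_ = -2;` where the member has type unsigned int.
unsigned int MissingSpaceSentinel() {
  unsigned int space_label = -2;  // converted modulo 2^N, so it wraps around
  assert(space_label == std::numeric_limits<unsigned int>::max() - 1);
  return space_label;
}

// True when an alphabet with `alphabet_size` entries contains a real label
// equal to the sentinel (labels run from 0 to alphabet_size - 1).
bool SentinelCollides(unsigned long long alphabet_size) {
  return alphabet_size > MissingSpaceSentinel();
}
```

Under these assumptions, `SentinelCollides` only returns true once the alphabet holds at least UINT_MAX entries, which matches the size quoted in the comment.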

Contributor Author:
Thinking about it, I don't think this is a big deal, as languages without a space label will always use character-based language models.

Contributor:
Is it true that "languages without a space label will always use character-based language models"? For example, Thai has an alphabet, but words can also be written without spaces. I think Javanese is the same way.

Contributor Author:
The -2 constant would be a problem for a language that uses a word-based language model and has an alphabet that is 2^32-2 characters long. AFAIK, the CJK languages are the only ones that could possibly get close to that many characters, and they would all probably use character-level LMs. Unicode defines close to 100,000 CJK ideograms, so even then we still have a lot of margin. I can add a separate "has_space" flag to remove the in-band signaling here, but I don't think it would be a problem anytime soon.
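
Note: the separate "has_space" flag mentioned at the end of this comment could be sketched roughly as follows. This is an illustration under assumptions, not the change that was merged: `AlphabetSketch` and `HasSpace` are hypothetical names, and the constructor takes an in-memory list of labels instead of reading a config file.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical variant of Alphabet, reduced to the space-label bookkeeping.
// A separate flag records whether a space label exists, so no label value
// has to be reserved as an in-band sentinel.
class AlphabetSketch {
 public:
  explicit AlphabetSketch(const std::vector<std::string>& labels)
      : has_space_(false), space_label_(0), label_to_str_(labels) {
    for (unsigned int label = 0; label < label_to_str_.size(); ++label) {
      if (label_to_str_[label] == " ") {
        has_space_ = true;
        space_label_ = label;
      }
    }
  }

  bool HasSpace() const { return has_space_; }

  bool IsSpace(unsigned int label) const {
    return has_space_ && label == space_label_;
  }

  // Only meaningful when HasSpace() is true.
  unsigned int GetSpaceLabel() const {
    assert(has_space_);
    return space_label_;
  }

 private:
  bool has_space_;
  unsigned int space_label_;
  std::vector<std::string> label_to_str_;
};
```

With this shape, `IsSpace` is false for every label when the alphabet has no space entry, regardless of alphabet size, at the cost of one extra member and a check in callers of `GetSpaceLabel`.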

for (std::string line; std::getline(in, line);) {
if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
line = '#';
} else if (line[0] == '#') {
continue;
}
- label_to_str_[label] = line;
+ //TODO: we should probably do something more i18n-aware here
+ if (line == " ") {
+ space_label_ = label;
+ }
+ label_to_str_.push_back(line);
str_to_label_[line] = label;
++label;
}
+ label_to_str_.push_back("*");
size_ = label;
in.close();
}

const std::string& StringFromLabel(unsigned int label) const {
assert(label < size_);
- auto it = label_to_str_.find(label);
- if (it != label_to_str_.end()) {
- return it->second;
- } else {
- // unreachable due to assert above
- abort();
- }
+ return label_to_str_[label];
}

unsigned int LabelFromString(const std::string& string) const {
@@ -52,19 +53,31 @@ class Alphabet {
}
}

- size_t GetSize() {
+ size_t GetSize() const {
return size_;
}

bool IsSpace(unsigned int label) const {
- //TODO: we should probably do something more i18n-aware here
- const std::string& str = StringFromLabel(label);
- return str.size() == 1 && str[0] == ' ';
+ return label == space_label_;
}

+ unsigned int GetSpaceLabel() const {
+ return space_label_;
+ }

+ template <typename T>
+ std::string LabelsToString(const std::vector<T>& input) const {
+ std::string word;
+ for (auto ind : input) {
+ word += StringFromLabel(ind);
+ }
+ return word;
+ }

private:
size_t size_;
- std::unordered_map<unsigned int, std::string> label_to_str_;
+ unsigned int space_label_;
+ std::vector<std::string> label_to_str_;
Contributor:
Why switch from std::unordered_map to std::vector? Just curious.

Contributor Author:
The ctcdecode code uses a vector<string> as the alphabet representation so I initially was just passing Alphabet's vector to it, but eventually I made it use the Alphabet class directly instead. But std::unordered_map is unnecessary here as a vector is sufficient and faster/leaner.
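
Note: the reasoning here is that labels are dense integers starting at 0, so the label-to-string direction needs no hashing, while the string-to-label direction still does. A minimal sketch of that layout (an illustration only; `LabelTablesSketch` and `Add` are hypothetical names, not part of the PR) might look like this:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for Alphabet's two lookup tables. Labels are dense
// integers 0..N-1, so label -> string is a plain vector index, while
// string -> label still uses a hash map.
class LabelTablesSketch {
 public:
  // Assigns the next free label to `str`.
  void Add(const std::string& str) {
    const unsigned int label = static_cast<unsigned int>(label_to_str_.size());
    label_to_str_.push_back(str);
    str_to_label_[str] = label;
  }

  const std::string& StringFromLabel(unsigned int label) const {
    assert(label < label_to_str_.size());
    return label_to_str_[label];  // direct index, no hashing
  }

  unsigned int LabelFromString(const std::string& str) const {
    auto it = str_to_label_.find(str);
    assert(it != str_to_label_.end());
    return it->second;
  }

 private:
  std::vector<std::string> label_to_str_;
  std::unordered_map<std::string, unsigned int> str_to_label_;
};
```

The vector also keeps the strings in label order, which is convenient when a decoder wants to iterate over the whole alphabet.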

std::unordered_map<std::string, unsigned int> str_to_label_;
};

2 changes: 1 addition & 1 deletion native_client/client.cc
@@ -200,7 +200,7 @@ main(int argc, char **argv)
return 1;
}

- if (lm && trie) {
+ if (lm) {
int status = DS_EnableDecoderWithLM(ctx,
alphabet,
lm,
22 changes: 22 additions & 0 deletions native_client/ctcdecode/LICENSE
@@ -0,0 +1,22 @@
MIT License

Copyright (c) 2017 Ryan Leary

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

2 changes: 2 additions & 0 deletions native_client/ctcdecode/README.mozilla
@@ -0,0 +1,2 @@
Originally imported from https://github.com/parlance/ctcdecode, commit 140b45860cec6671fb0bf6dbb675073241c0f9b0