Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subtract a low order n-gram LM from a high order n-gram LM #24

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions kaldilm/csrc/arpa_lm_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,29 @@ struct VectorHasher { // hashing function for vector<Int>.
static const int kPrime = 7853;
};

class ArpaLmCompilerImplInterface {
public:
virtual ~ArpaLmCompilerImplInterface() = default;
virtual void ConsumeNGram(const NGram &ngram, bool is_highest) = 0;
};

namespace {

typedef int32_t StateId;
typedef int32_t Symbol;

struct StateWeight {
StateId state;
float weight; // -ngram.logprob
float backoff; // -ngram.backoff
};

// GeneralHistKey can represent state history in an arbitrarily large n
// n-gram model with symbol ids fitting int32_t.
class GeneralHistKey {
public:
// Construct key from being and end iterators.
// Construct key from begin and end iterators.
template <class InputIt>
GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) {}
// Construct empty history key.
GeneralHistKey() : vector_() {}
// Return tails of the key as a GeneralHistKey. The tails of an n-gram
// w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
// key class does not need this operartion).
// key class does not need this operation).
GeneralHistKey Tails() const {
return GeneralHistKey(vector_.begin() + 1, vector_.end());
}
Expand Down Expand Up @@ -118,13 +118,33 @@ class OptimizedHistKey {

} // namespace

class ArpaLmCompilerImplInterface {
public:
virtual ~ArpaLmCompilerImplInterface() = default;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any chance of some short comments here explaining the interface?
the only other comment I have is, of course check it doesn't change the behavior for normal inputs.

virtual void ConsumeNGram(
const NGram &ngram, bool is_highest,
ArpaLmCompilerImplInterface *low_order = nullptr) = 0;
virtual StateWeight GetWeight(const std::vector<int32_t> &ngram) const = 0;
};

template <class HistKey>
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
public:
ArpaLmCompilerImpl(ArpaLmCompiler *parent, fst::StdVectorFst *fst,
Symbol sub_eps);

virtual void ConsumeNGram(const NGram &ngram, bool is_highest);
void ConsumeNGram(const NGram &ngram, bool is_highest,
ArpaLmCompilerImplInterface *low_order = nullptr) override;

StateWeight GetWeight(const std::vector<int32_t> &ngram) const override {
HistKey words(ngram.begin(), ngram.end());
auto it = this->history_.find(words);
if (it == this->history_.end()) {
return {0, 0, 0};
} else {
return it->second;
}
}

private:
StateId AddStateWithBackoff(HistKey key, float backoff);
Expand All @@ -137,7 +157,7 @@ class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
Symbol sub_eps_;

StateId eos_state_;
typedef std::unordered_map<HistKey, StateId, typename HistKey::HashType>
typedef std::unordered_map<HistKey, StateWeight, typename HistKey::HashType>
HistoryMap;
HistoryMap history_;
};
Expand All @@ -154,7 +174,7 @@ ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(ArpaLmCompiler *parent,
// The algorithm maintains state per history. The 0-gram is a special state
// for empty history. All unigrams (including BOS) backoff into this state.
StateId zerogram = fst_->AddState();
history_[HistKey()] = zerogram;
history_[HistKey()] = {zerogram, 0, 0};

// Also, if </s> is not treated as epsilon, create a common end state for
// all transitions accepting the </s>, since they do not back off. This small
Expand All @@ -166,8 +186,9 @@ ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(ArpaLmCompiler *parent,
}

template <class HistKey>
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
bool is_highest) {
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(
const NGram &ngram, bool is_highest,
ArpaLmCompilerImplInterface *low_order) {
// Generally, we do the following. Suppose we are adding an n-gram "A B
// C". Then find the node for "A B", add a new node for "A B C", and connect
// them with the arc accepting "C" with the specified weight. Also, add a
Expand Down Expand Up @@ -203,10 +224,27 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
return;
}

StateId source = source_it->second;
StateId source = source_it->second.state;
StateId dest;
Symbol sym = ngram.words.back();
float weight = -ngram.logprob;
float backoff = -ngram.backoff;

if (low_order == nullptr) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it necessary that the lm to subtract is actually of lower order? If it would work regardless, we might as well make it as general as possible and name it accordingly.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... if it can be done elegantly, we can consider generaly linear combinations of logprobs, not just one minus the other.

source_it->second.weight = weight;
source_it->second.backoff = backoff;
} else {
auto state_weight = low_order->GetWeight(ngram.words);
// if ngram.words does not exist in low_order, then
// state_weight.weight and state_weight.backoff are both zero

// -(log_prob_high_order - log_prob_low_order)
// = -log_prob_high_order + log_prob_low_order
// = weight - state_weight.weight
weight -= state_weight.weight;
backoff -= state_weight.backoff;
}

if (sym == sub_eps_ || sym == 0) {
KALDILM_ERR << " <eps> or disambiguation symbol " << sym
<< "found in the ARPA file. ";
Expand All @@ -228,7 +266,7 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
// so we better do not do that at all).
dest = AddStateWithBackoff(
HistKey(ngram.words.begin() + (is_highest ? 1 : 0), ngram.words.end()),
-ngram.backoff);
backoff);
}

if (sym == bos_symbol_) {
Expand Down Expand Up @@ -260,11 +298,11 @@ StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
if (dest_it != history_.end()) {
// Found an existing state in the history map. Invariant: if the state in
// the map, then its backoff arc is in the FST. We are done.
return dest_it->second;
return dest_it->second.state;
}
// Otherwise create a new state and its backoff arc, and register in the map.
StateId dest = fst_->AddState();
history_[key] = dest;
history_[key] = {dest, 0, 0};
CreateBackoff(key.Tails(), dest, backoff);
return dest;
}
Expand All @@ -286,7 +324,7 @@ inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(HistKey key,
// The arc should transduce either <eos> or #0 to <eps>, depending on the
// epsilon substitution mode. This is the only case when input and output
// label may differ.
fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second.state));
}

ArpaLmCompiler::~ArpaLmCompiler() {
Expand Down Expand Up @@ -328,7 +366,8 @@ void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
}

bool is_highest = ngram.words.size() == NgramCounts().size();
impl_->ConsumeNGram(ngram, is_highest);
impl_->ConsumeNGram(ngram, is_highest,
low_order_ ? low_order_->impl_ : nullptr);
}

void ArpaLmCompiler::RemoveRedundantStates() {
Expand Down
10 changes: 7 additions & 3 deletions kaldilm/csrc/arpa_lm_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ class ArpaLmCompilerImplInterface;
class ArpaLmCompiler : public ArpaFileParser {
public:
ArpaLmCompiler(const ArpaParseOptions &options, int sub_eps,
fst::SymbolTable *symbols)
: ArpaFileParser(options, symbols), sub_eps_(sub_eps), impl_(nullptr) {}
fst::SymbolTable *symbols, ArpaLmCompiler *low_order = nullptr)
: ArpaFileParser(options, symbols),
sub_eps_(sub_eps),
impl_(nullptr),
low_order_(low_order) {}
~ArpaLmCompiler();

const fst::StdVectorFst &Fst() const { return fst_; }
Expand All @@ -37,7 +40,8 @@ class ArpaLmCompiler : public ArpaFileParser {
void Check() const;

int sub_eps_;
ArpaLmCompilerImplInterface *impl_; // Owned.
ArpaLmCompilerImplInterface *impl_; // Owned.
ArpaLmCompiler *low_order_ = nullptr; // Owned.
fst::StdVectorFst fst_;
template <class HistKey>
friend class ArpaLmCompilerImpl;
Expand Down
2 changes: 1 addition & 1 deletion kaldilm/csrc/arpa_lm_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ ArpaLmCompiler *Compile(bool seps, const std::string &infile) {
// Tests in this form cannot be run with epsilon substitution, unless every
// random path is also fitted with a #0-transducing self-loop.
ArpaLmCompiler *lm_compiler =
new ArpaLmCompiler(options, seps ? kDisambig : 0, &symbols);
new ArpaLmCompiler(options, seps ? kDisambig : 0, &symbols, nullptr);
{
std::ifstream inf(infile);
lm_compiler->Read(inf);
Expand Down
106 changes: 106 additions & 0 deletions kaldilm/python/csrc/kaldilm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,104 @@ std::string Arpa2Fst(const std::string &input_arpa,
return os.str();
}

std::string Arpa2Fst2(
const std::string &input_arpa, const std::string &input_arpa2,
const std::string &output_fst = "", const std::string bos_symbol = "<s>",
const std::string &disambig_symbol = "",
const std::string &eos_symbol = "</s>", bool ilabel_sort = true,
bool keep_symbols = false, int32_t max_arpa_warnings = 30,
const std::string &read_symbol_table = "",
const std::string &write_symbol_table = "", int32_t max_order = -1) {
ArpaParseOptions options;
options.max_order = max_order;
options.max_warnings = max_arpa_warnings;

std::string read_syms_filename = read_symbol_table;
std::string write_syms_filename = write_symbol_table;

std::string arpa_rxfilename = input_arpa;
std::string arpa_rxfilename2 = input_arpa2;
std::string fst_wxfilename = output_fst;

int64 disambig_symbol_id = 0;

fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
std::ifstream kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(kisym, read_syms_filename);
if (symbols == nullptr)
KALDILM_ERR << "Could not read symbol table from file "
<< read_syms_filename;

options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDILM_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(fst_wxfilename);
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}

// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);

// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;

// Actually compile LM.
KALDILM_ASSERT(symbols != nullptr);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
std::fstream ki(arpa_rxfilename);
lm_compiler.Read(ki);
}

// TODO(fangjun): Don't share the symbol table
ArpaLmCompiler lm_compiler2(options, disambig_symbol_id, symbols,
&lm_compiler);
{
std::fstream ki(arpa_rxfilename2);
lm_compiler2.Read(ki);
}

// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler2.MutableFst(), fst::StdILabelCompare());
}

// Write symbols if requested.
if (!write_syms_filename.empty()) {
std::ofstream kosym(write_syms_filename);
symbols->WriteText(kosym);
}

// Write LM FST.
if (fst_wxfilename.size() > 0) {
std::ofstream kofst(fst_wxfilename, std::ios::binary);
fst::FstWriteOptions wopts(fst_wxfilename);
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler2.Fst().Write(kofst, wopts);
}

delete symbols;

std::ostringstream os;
PrintFstInTextFormat<fst::StdArc>(os, lm_compiler2.Fst());
return os.str();
}

} // namespace kaldilm

PYBIND11_MODULE(_kaldilm, m) {
Expand All @@ -143,4 +241,12 @@ PYBIND11_MODULE(_kaldilm, m) {
py::arg("ilabel_sort") = true, py::arg("keep_symbols") = false,
py::arg("max_arpa_warnings") = 30, py::arg("read_symbol_table") = "",
py::arg("write_symbol_table") = "", py::arg("max_order") = -1);

m.def("arpa2fst2", &kaldilm::Arpa2Fst2, py::arg("input_arpa"),
py::arg("input_arpa2"), py::arg("output_fst") = "",
py::arg("bos_symbol") = "<s>", py::arg("disambig_symbol") = "",
py::arg("eos_symbol") = "</s>", py::arg("ilabel_sort") = true,
py::arg("keep_symbols") = false, py::arg("max_arpa_warnings") = 30,
py::arg("read_symbol_table") = "", py::arg("write_symbol_table") = "",
py::arg("max_order") = -1);
}
1 change: 1 addition & 0 deletions kaldilm/python/kaldilm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .arpa2fst import arpa2fst
from _kaldilm import arpa2fst2