-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Subtract a low order n-gram LM from a high order n-gram LM #24
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,29 +42,29 @@ struct VectorHasher { // hashing function for vector<Int>. | |
static const int kPrime = 7853; | ||
}; | ||
|
||
class ArpaLmCompilerImplInterface { | ||
public: | ||
virtual ~ArpaLmCompilerImplInterface() = default; | ||
virtual void ConsumeNGram(const NGram &ngram, bool is_highest) = 0; | ||
}; | ||
|
||
namespace { | ||
|
||
typedef int32_t StateId; | ||
typedef int32_t Symbol; | ||
|
||
struct StateWeight { | ||
StateId state; | ||
float weight; // -ngram.logprob | ||
float backoff; // -ngram.backoff | ||
}; | ||
|
||
// GeneralHistKey can represent state history in an arbitrarily large n | ||
// n-gram model with symbol ids fitting int32_t. | ||
class GeneralHistKey { | ||
public: | ||
// Construct key from being and end iterators. | ||
// Construct key from begin and end iterators. | ||
template <class InputIt> | ||
GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) {} | ||
// Construct empty history key. | ||
GeneralHistKey() : vector_() {} | ||
// Return tails of the key as a GeneralHistKey. The tails of an n-gram | ||
// w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the | ||
// key class does not need this operartion). | ||
// key class does not need this operation). | ||
GeneralHistKey Tails() const { | ||
return GeneralHistKey(vector_.begin() + 1, vector_.end()); | ||
} | ||
|
@@ -118,13 +118,33 @@ class OptimizedHistKey { | |
|
||
} // namespace | ||
|
||
class ArpaLmCompilerImplInterface { | ||
public: | ||
virtual ~ArpaLmCompilerImplInterface() = default; | ||
virtual void ConsumeNGram( | ||
const NGram &ngram, bool is_highest, | ||
ArpaLmCompilerImplInterface *low_order = nullptr) = 0; | ||
virtual StateWeight GetWeight(const std::vector<int32_t> &ngram) const = 0; | ||
}; | ||
|
||
template <class HistKey> | ||
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface { | ||
public: | ||
ArpaLmCompilerImpl(ArpaLmCompiler *parent, fst::StdVectorFst *fst, | ||
Symbol sub_eps); | ||
|
||
virtual void ConsumeNGram(const NGram &ngram, bool is_highest); | ||
void ConsumeNGram(const NGram &ngram, bool is_highest, | ||
ArpaLmCompilerImplInterface *low_order = nullptr) override; | ||
|
||
StateWeight GetWeight(const std::vector<int32_t> &ngram) const override { | ||
HistKey words(ngram.begin(), ngram.end()); | ||
auto it = this->history_.find(words); | ||
if (it == this->history_.end()) { | ||
return {0, 0, 0}; | ||
} else { | ||
return it->second; | ||
} | ||
} | ||
|
||
private: | ||
StateId AddStateWithBackoff(HistKey key, float backoff); | ||
|
@@ -137,7 +157,7 @@ class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface { | |
Symbol sub_eps_; | ||
|
||
StateId eos_state_; | ||
typedef std::unordered_map<HistKey, StateId, typename HistKey::HashType> | ||
typedef std::unordered_map<HistKey, StateWeight, typename HistKey::HashType> | ||
HistoryMap; | ||
HistoryMap history_; | ||
}; | ||
|
@@ -154,7 +174,7 @@ ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(ArpaLmCompiler *parent, | |
// The algorithm maintains state per history. The 0-gram is a special state | ||
// for empty history. All unigrams (including BOS) backoff into this state. | ||
StateId zerogram = fst_->AddState(); | ||
history_[HistKey()] = zerogram; | ||
history_[HistKey()] = {zerogram, 0, 0}; | ||
|
||
// Also, if </s> is not treated as epsilon, create a common end state for | ||
// all transitions accepting the </s>, since they do not back off. This small | ||
|
@@ -166,8 +186,9 @@ ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(ArpaLmCompiler *parent, | |
} | ||
|
||
template <class HistKey> | ||
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram, | ||
bool is_highest) { | ||
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram( | ||
const NGram &ngram, bool is_highest, | ||
ArpaLmCompilerImplInterface *low_order) { | ||
// Generally, we do the following. Suppose we are adding an n-gram "A B | ||
// C". Then find the node for "A B", add a new node for "A B C", and connect | ||
// them with the arc accepting "C" with the specified weight. Also, add a | ||
|
@@ -203,10 +224,27 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram, | |
return; | ||
} | ||
|
||
StateId source = source_it->second; | ||
StateId source = source_it->second.state; | ||
StateId dest; | ||
Symbol sym = ngram.words.back(); | ||
float weight = -ngram.logprob; | ||
float backoff = -ngram.backoff; | ||
|
||
if (low_order == nullptr) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it necessary that the LM to subtract is actually of lower order? If it would work regardless, we might as well make it as general as possible and name it accordingly. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ... If it can be done elegantly, we could consider general linear combinations of logprobs, not just one minus the other. |
||
source_it->second.weight = weight; | ||
source_it->second.backoff = backoff; | ||
} else { | ||
auto state_weight = low_order->GetWeight(ngram.words); | ||
// if ngram.words does not exist in low_order, then | ||
// state_weight.weight and state_weight.backoff are both zero | ||
|
||
// -(log_prob_high_order - log_prob_low_order) | ||
// = -log_prob_high_order + log_prob_low_order | ||
// = weight - state_weight.weight | ||
weight -= state_weight.weight; | ||
backoff -= state_weight.backoff; | ||
} | ||
|
||
if (sym == sub_eps_ || sym == 0) { | ||
KALDILM_ERR << " <eps> or disambiguation symbol " << sym | ||
<< "found in the ARPA file. "; | ||
|
@@ -228,7 +266,7 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram, | |
// so we better do not do that at all). | ||
dest = AddStateWithBackoff( | ||
HistKey(ngram.words.begin() + (is_highest ? 1 : 0), ngram.words.end()), | ||
-ngram.backoff); | ||
backoff); | ||
} | ||
|
||
if (sym == bos_symbol_) { | ||
|
@@ -260,11 +298,11 @@ StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key, | |
if (dest_it != history_.end()) { | ||
// Found an existing state in the history map. Invariant: if the state in | ||
// the map, then its backoff arc is in the FST. We are done. | ||
return dest_it->second; | ||
return dest_it->second.state; | ||
} | ||
// Otherwise create a new state and its backoff arc, and register in the map. | ||
StateId dest = fst_->AddState(); | ||
history_[key] = dest; | ||
history_[key] = {dest, 0, 0}; | ||
CreateBackoff(key.Tails(), dest, backoff); | ||
return dest; | ||
} | ||
|
@@ -286,7 +324,7 @@ inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(HistKey key, | |
// The arc should transduce either <eos> or #0 to <eps>, depending on the | ||
// epsilon substitution mode. This is the only case when input and output | ||
// label may differ. | ||
fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second)); | ||
fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second.state)); | ||
} | ||
|
||
ArpaLmCompiler::~ArpaLmCompiler() { | ||
|
@@ -328,7 +366,8 @@ void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) { | |
} | ||
|
||
bool is_highest = ngram.words.size() == NgramCounts().size(); | ||
impl_->ConsumeNGram(ngram, is_highest); | ||
impl_->ConsumeNGram(ngram, is_highest, | ||
low_order_ ? low_order_->impl_ : nullptr); | ||
} | ||
|
||
void ArpaLmCompiler::RemoveRedundantStates() { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .arpa2fst import arpa2fst | ||
from _kaldilm import arpa2fst2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any chance of some short comments here explaining the interface?
The only other comment I have is, of course, to check that it doesn't change the behavior for normal inputs.