From 9248a6bc78b1ef772255d17077da19783daa52aa Mon Sep 17 00:00:00 2001 From: Chen Gong Date: Thu, 14 Mar 2019 06:57:19 +0800 Subject: [PATCH] feat(grammar): compare homophones/homographs in sentence add inteface to grammar plugin; fall back to naive formula if missing "grammar" module --- src/rime/dict/dictionary.h | 2 + src/rime/gear/grammar.h | 31 +++++++++++ src/rime/gear/poet.cc | 27 +++++++--- src/rime/gear/poet.h | 6 ++- src/rime/gear/script_translator.cc | 15 ++++-- src/rime/gear/script_translator.h | 4 ++ src/rime/gear/table_translator.cc | 81 +++++++++++++++++++++++------ src/rime/gear/table_translator.h | 3 ++ src/rime/gear/translator_commons.cc | 14 +++-- src/rime/gear/translator_commons.h | 7 ++- 10 files changed, 156 insertions(+), 34 deletions(-) create mode 100644 src/rime/gear/grammar.h diff --git a/src/rime/dict/dictionary.h b/src/rime/dict/dictionary.h index fbb307462f..f17a56ec27 100644 --- a/src/rime/dict/dictionary.h +++ b/src/rime/dict/dictionary.h @@ -45,6 +45,8 @@ class DictEntryIterator : public DictEntryFilterBinder { DictEntryIterator() = default; DictEntryIterator(DictEntryIterator&& other) = default; DictEntryIterator& operator= (DictEntryIterator&& other) = default; + DictEntryIterator(const DictEntryIterator& other) = default; + DictEntryIterator& operator= (const DictEntryIterator& other) = default; void AddChunk(dictionary::Chunk&& chunk, Table* table); void Sort(); diff --git a/src/rime/gear/grammar.h b/src/rime/gear/grammar.h new file mode 100644 index 0000000000..09923d8ac8 --- /dev/null +++ b/src/rime/gear/grammar.h @@ -0,0 +1,31 @@ +#ifndef RIME_GRAMMAR_H_ +#define RIME_GRAMMAR_H_ + +#include +#include +#include + +namespace rime { + +class Config; + +class Grammar : public Class { + public: + virtual ~Grammar() {} + virtual double Query(const string& context, + const string& word, + bool is_rear) = 0; + + inline static double Evaluate(const string& context, + const DictEntry& entry, + bool is_rear, + Grammar* grammar) { + const double kPenalty = -18.420680743952367; // log(1e-8) + return entry.weight + + (grammar ? grammar->Query(context, entry.text, is_rear) : kPenalty); + } +}; + +} // namespace rime + +#endif // RIME_GRAMMAR_H_ diff --git a/src/rime/gear/poet.cc b/src/rime/gear/poet.cc index d24d5afc46..afd3310aa6 100644 --- a/src/rime/gear/poet.cc +++ b/src/rime/gear/poet.cc @@ -6,16 +6,30 @@ // // 2011-10-06 GONG Chen // -#include #include +#include #include +#include #include namespace rime { +inline static Grammar* create_grammar(Config* config) { + if (auto* grammar = Grammar::Require("grammar")) { + return grammar->Create(config); + } + return nullptr; +} + +Poet::Poet(const Language* language, Config* config) + : language_(language), + grammar_(create_grammar(config)) {} + +Poet::~Poet() {} + an Poet::MakeSentence(const WordGraph& graph, - size_t total_length) { - const int kMaxHomophonesInMind = 1; + size_t total_length) { + // TODO: save more intermediate sentence candidates map> sentences; sentences[0] = New(language_); // dynamic programming @@ -30,15 +44,16 @@ an Poet::MakeSentence(const WordGraph& graph, continue; // exclude single words from the result DLOG(INFO) << "end pos: " << end_pos; const DictEntryList& entries(x.second); - for (size_t i = 0; i < kMaxHomophonesInMind && i < entries.size(); ++i) { + for (size_t i = 0; i < entries.size(); ++i) { const auto& entry(entries[i]); auto new_sentence = New(*sentences[start_pos]); - new_sentence->Extend(*entry, end_pos); + bool is_rear = end_pos == total_length; + new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get()); if (sentences.find(end_pos) == sentences.end() || sentences[end_pos]->weight() < new_sentence->weight()) { DLOG(INFO) << "updated sentences " << end_pos << ") with '" << new_sentence->text() << "', " << new_sentence->weight(); - sentences[end_pos] = new_sentence; + sentences[end_pos] = std::move(new_sentence); } } } diff --git a/src/rime/gear/poet.h b/src/rime/gear/poet.h index df27da5498..dfe8bcb1ce 100644 --- a/src/rime/gear/poet.h +++ b/src/rime/gear/poet.h @@ -10,6 +10,7 @@ #ifndef RIME_POET_H_ #define RIME_POET_H_ +#include #include #include @@ -17,16 +18,19 @@ namespace rime { using WordGraph = map; +class Grammar; class Language; class Poet { public: - Poet(const Language* language) : language_(language) {} + Poet(const Language* language, Config* config); + ~Poet(); an MakeSentence(const WordGraph& graph, size_t total_length); protected: const Language* language_; + the grammar_; }; } // namespace rime diff --git a/src/rime/gear/script_translator.cc b/src/rime/gear/script_translator.cc index 9d8e838659..ebf16fecf9 100644 --- a/src/rime/gear/script_translator.cc +++ b/src/rime/gear/script_translator.cc @@ -103,9 +103,11 @@ class ScriptTranslation : public Translation { public: ScriptTranslation(ScriptTranslator* translator, Corrector* corrector, + Poet* poet, const string& input, size_t start) : translator_(translator), + poet_(poet), start_(start), syllabifier_(New( translator, corrector, input, start)), @@ -124,6 +126,7 @@ class ScriptTranslation : public Translation { void PrepareCandidate(); ScriptTranslator* translator_; + Poet* poet_; size_t start_; an syllabifier_; @@ -156,6 +159,8 @@ ScriptTranslator::ScriptTranslator(const Ticket& ticket) config->GetBool(name_space_ + "/always_show_comments", &always_show_comments_); config->GetBool(name_space_ + "/enable_correction", &enable_correction_); + config->GetInt(name_space_ + "/max_homophones", &max_homophones_); + poet_.reset(new Poet(language(), config)); } if (enable_correction_) { if (auto* corrector = Corrector::Require("corrector")) { @@ -181,6 +186,7 @@ an ScriptTranslator::Query(const string& input, // the translator should survive translations it creates auto result = New(this, corrector_.get(), + poet_.get(), input, segment.start); if (!result || @@ -523,15 +529,16 @@ an ScriptTranslation::MakeSentence(Dictionary* dict, // merge lookup results for (auto& y : *phrase) { DictEntryList& entries(dest[y.first]); - if (entries.empty()) { + while (entries.size() < translator_->max_homophones() && + !y.second.exhausted()) { entries.push_back(y.second.Peek()); + if (!y.second.Next()) + break; } } } } - Poet poet(translator_->language()); - auto sentence = poet.MakeSentence(graph, - syllable_graph.interpreted_length); + auto sentence = poet_->MakeSentence(graph, syllable_graph.interpreted_length); if (sentence) { sentence->Offset(start_); sentence->set_syllabifier(syllabifier_); diff --git a/src/rime/gear/script_translator.h b/src/rime/gear/script_translator.h index 7867161d04..ea799e65e3 100644 --- a/src/rime/gear/script_translator.h +++ b/src/rime/gear/script_translator.h @@ -21,6 +21,7 @@ class Corrector; struct DictEntry; struct DictEntryCollector; class Dictionary; +class Poet; class UserDictionary; struct SyllableGraph; @@ -38,14 +39,17 @@ class ScriptTranslator : public Translator, string Spell(const Code& code); // options + int max_homophones() const { return max_homophones_; } int spelling_hints() const { return spelling_hints_; } bool always_show_comments() const { return always_show_comments_; } protected: + int max_homophones_ = 1; int spelling_hints_ = 0; bool always_show_comments_ = false; bool enable_correction_ = false; the corrector_; + the poet_; }; } // namespace rime diff --git a/src/rime/gear/table_translator.cc b/src/rime/gear/table_translator.cc index 472e837615..3fbbf0b8b0 100644 --- a/src/rime/gear/table_translator.cc +++ b/src/rime/gear/table_translator.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -216,6 +217,13 @@ TableTranslator::TableTranslator(const Ticket& ticket) &encode_commit_history_); config->GetInt(name_space_ + "/max_phrase_length", &max_phrase_length_); + config->GetInt(name_space_ + "/max_homographs", + &max_homographs_); + if (enable_sentence_ || sentence_over_completion_) { + if (auto* grammar_component = Grammar::Require("grammar")) { + grammar_.reset(grammar_component->Create(config)); + } + } } if (enable_encoder_ && user_dict_) { encoder_.reset(new UnityTableEncoder(user_dict_.get())); @@ -231,7 +239,7 @@ static bool starts_with_completion(an translation) { } an TableTranslator::Query(const string& input, - const Segment& segment) { + const Segment& segment) { if (!segment.HasTag(tag_)) return nullptr; DLOG(INFO) << "input = '" << input @@ -519,7 +527,7 @@ bool SentenceTranslation::PreferUserPhrase() const { return false; } -static size_t consume_trailing_delimiters(size_t pos, +inline static size_t consume_trailing_delimiters(size_t pos, const string& input, const string& delimiters) { while (pos < input.length() && @@ -529,11 +537,25 @@ static size_t consume_trailing_delimiters(size_t pos, return pos; } +template +inline static void collect_entries(DictEntryList& dest, + Iter& iter, + int max_entries) { + if (dest.size() < max_entries && !iter.exhausted()) { + dest.push_back(iter.Peek()); + // alters iter if collecting more than 1 entries + while (dest.size() < max_entries && iter.Next()) { + dest.push_back(iter.Peek()); + } + } +} + an TableTranslator::MakeSentence(const string& input, size_t start, bool include_prefix_phrases) { bool filter_by_charset = enable_charset_filter_ && !engine_->context()->get_option("extended_charset"); + const int max_entries = max_homographs_; DictEntryCollector collector; UserDictEntryCollector user_phrase_collector; map> sentences; @@ -543,13 +565,14 @@ TableTranslator::MakeSentence(const string& input, size_t start, continue; string active_input = input.substr(start_pos); string active_key = active_input + ' '; - vector> entries(active_input.length() + 1); + UserDictEntryCollector collected_entries; // lookup dictionaries if (user_dict_ && user_dict_->loaded()) { for (size_t len = 1; len <= active_input.length(); ++len) { size_t consumed_length = consume_trailing_delimiters(len, active_input, delimiters_); - if (entries[consumed_length]) + auto& dest(collected_entries[consumed_length]); + if (dest.size() >= max_entries) continue; DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")"; UserDictEntryIterator uter; @@ -560,9 +583,15 @@ TableTranslator::MakeSentence(const string& input, size_t start, uter.AddFilter(CharsetFilter::FilterDictEntry); } if (!uter.exhausted()) { - entries[consumed_length] = uter.Peek(); + if (start_pos == 0 && max_entries > 1) { + UserDictEntryIterator uter_copy(uter); + collect_entries(dest, uter_copy, max_entries); + } else { + collect_entries(dest, uter, max_entries); + } if (start_pos == 0) { // also provide words for manual composition + // uter must not be consumed uter.Release(&user_phrase_collector[consumed_length]); DLOG(INFO) << "user phrase[" << consumed_length << "]: " << user_phrase_collector[consumed_length].size(); @@ -578,7 +607,8 @@ TableTranslator::MakeSentence(const string& input, size_t start, for (size_t len = 1; len <= active_input.length(); ++len) { size_t consumed_length = consume_trailing_delimiters(len, active_input, delimiters_); - if (entries[consumed_length]) + auto& dest(collected_entries[consumed_length]); + if (!dest.empty()) continue; DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")"; UserDictEntryIterator uter; @@ -589,9 +619,15 @@ TableTranslator::MakeSentence(const string& input, size_t start, uter.AddFilter(CharsetFilter::FilterDictEntry); } if (!uter.exhausted()) { - entries[consumed_length] = uter.Peek(); + if (start_pos == 0 && max_entries > 1) { + UserDictEntryIterator uter_copy(uter); + collect_entries(dest, uter_copy, max_entries); + } else { + collect_entries(dest, uter, max_entries); + } if (start_pos == 0) { // also provide words for manual composition + // uter must not be consumed uter.Release(&user_phrase_collector[consumed_length]); DLOG(INFO) << "unity phrase[" << consumed_length << "]: " << user_phrase_collector[consumed_length].size(); @@ -612,7 +648,8 @@ TableTranslator::MakeSentence(const string& input, size_t start, continue; size_t consumed_length = consume_trailing_delimiters(m.length, active_input, delimiters_); - if (entries[consumed_length]) + auto& dest(collected_entries[consumed_length]); + if (dest.size() >= max_entries) continue; DictEntryIterator iter; dict_->LookupWords(&iter, active_input.substr(0, m.length), false); @@ -620,9 +657,15 @@ TableTranslator::MakeSentence(const string& input, size_t start, iter.AddFilter(CharsetFilter::FilterDictEntry); } if (!iter.exhausted()) { - entries[consumed_length] = iter.Peek(); + if (start_pos == 0 && max_entries - dest.size() > 1) { + DictEntryIterator iter_copy = iter; + collect_entries(dest, iter_copy, max_entries); + } else { + collect_entries(dest, iter, max_entries); + } if (start_pos == 0) { // also provide words for manual composition + // iter must not be consumed collector[consumed_length] = std::move(iter); DLOG(INFO) << "table[" << consumed_length << "]: " << collector[consumed_length].entry_count(); @@ -631,16 +674,20 @@ TableTranslator::MakeSentence(const string& input, size_t start, } } for (size_t len = 1; len <= active_input.length(); ++len) { - if (!entries[len]) + const auto& entries(collected_entries[len]); + if (entries.empty()) continue; size_t end_pos = start_pos + len; - // create a new sentence - auto new_sentence = New(*sentences[start_pos]); - new_sentence->Extend(*entries[len], end_pos); - // compare and update sentences - if (sentences.find(end_pos) == sentences.end() || - sentences[end_pos]->weight() <= new_sentence->weight()) { - sentences[end_pos] = std::move(new_sentence); + bool is_rear = end_pos == input.length(); + for (const auto& entry : entries) { + // create a new sentence + auto new_sentence = New(*sentences[start_pos]); + new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get()); + // compare and update sentences + if (sentences.find(end_pos) == sentences.end() || + sentences[end_pos]->weight() <= new_sentence->weight()) { + sentences[end_pos] = std::move(new_sentence); + } } } } diff --git a/src/rime/gear/table_translator.h b/src/rime/gear/table_translator.h index 3fcb637d86..3948fcb730 100644 --- a/src/rime/gear/table_translator.h +++ b/src/rime/gear/table_translator.h @@ -19,6 +19,7 @@ namespace rime { +class Grammar; class UnityTableEncoder; class TableTranslator : public Translator, @@ -44,7 +45,9 @@ class TableTranslator : public Translator, bool sentence_over_completion_ = false; bool encode_commit_history_ = true; int max_phrase_length_ = 5; + int max_homographs_ = 1; the encoder_; + the grammar_; }; class TableTranslation : public Translation { diff --git a/src/rime/gear/translator_commons.cc b/src/rime/gear/translator_commons.cc index 523c2f228f..63ee160484 100644 --- a/src/rime/gear/translator_commons.cc +++ b/src/rime/gear/translator_commons.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace rime { @@ -87,12 +88,15 @@ bool Spans::HasVertex(size_t vertex) const { // Sentence -void Sentence::Extend(const DictEntry& entry, size_t end_pos) { - const double kPenalty = -18.420680743952367; // log(1e-8) - entry_->code.insert(entry_->code.end(), - entry.code.begin(), entry.code.end()); +void Sentence::Extend(const DictEntry& entry, + size_t end_pos, + bool is_rear, + Grammar* grammar) { + entry_->weight += Grammar::Evaluate(entry_->text, entry, is_rear, grammar); entry_->text.append(entry.text); - entry_->weight += entry.weight + kPenalty; + entry_->code.insert(entry_->code.end(), + entry.code.begin(), + entry.code.end()); components_.push_back(entry); syllable_lengths_.push_back(end_pos - end()); set_end(end_pos); diff --git a/src/rime/gear/translator_commons.h b/src/rime/gear/translator_commons.h index 0f6b6b48c9..88e05aa97f 100644 --- a/src/rime/gear/translator_commons.h +++ b/src/rime/gear/translator_commons.h @@ -107,6 +107,8 @@ class Phrase : public Candidate { // +class Grammar; + class Sentence : public Phrase { public: Sentence(const Language* language) @@ -119,7 +121,10 @@ class Sentence : public Phrase { syllable_lengths_(other.syllable_lengths_) { entry_ = New(other.entry()); } - void Extend(const DictEntry& entry, size_t end_pos); + void Extend(const DictEntry& entry, + size_t end_pos, + bool is_rear, + Grammar* grammar); void Offset(size_t offset); const vector& components() const {