Skip to content

Commit 9248a6b

Browse files
committed
feat(grammar): compare homophones/homographs in sentence
add inteface to grammar plugin; fall back to naive formula if missing "grammar" module
1 parent fcf36bc commit 9248a6b

10 files changed

+156
-34
lines changed

src/rime/dict/dictionary.h

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class DictEntryIterator : public DictEntryFilterBinder {
4545
DictEntryIterator() = default;
4646
DictEntryIterator(DictEntryIterator&& other) = default;
4747
DictEntryIterator& operator= (DictEntryIterator&& other) = default;
48+
DictEntryIterator(const DictEntryIterator& other) = default;
49+
DictEntryIterator& operator= (const DictEntryIterator& other) = default;
4850

4951
void AddChunk(dictionary::Chunk&& chunk, Table* table);
5052
void Sort();

src/rime/gear/grammar.h

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef RIME_GRAMMAR_H_
2+
#define RIME_GRAMMAR_H_
3+
4+
#include <rime/common.h>
5+
#include <rime/component.h>
6+
#include <rime/dict/vocabulary.h>
7+
8+
namespace rime {
9+
10+
class Config;
11+
12+
class Grammar : public Class<Grammar, Config*> {
13+
public:
14+
virtual ~Grammar() {}
15+
virtual double Query(const string& context,
16+
const string& word,
17+
bool is_rear) = 0;
18+
19+
inline static double Evaluate(const string& context,
20+
const DictEntry& entry,
21+
bool is_rear,
22+
Grammar* grammar) {
23+
const double kPenalty = -18.420680743952367; // log(1e-8)
24+
return entry.weight +
25+
(grammar ? grammar->Query(context, entry.text, is_rear) : kPenalty);
26+
}
27+
};
28+
29+
} // namespace rime
30+
31+
#endif // RIME_GRAMMAR_H_

src/rime/gear/poet.cc

+21-6
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,30 @@
66
//
77
// 2011-10-06 GONG Chen <chen.sst@gmail.com>
88
//
9-
#include <rime/common.h>
109
#include <rime/candidate.h>
10+
#include <rime/config.h>
1111
#include <rime/dict/vocabulary.h>
12+
#include <rime/gear/grammar.h>
1213
#include <rime/gear/poet.h>
1314

1415
namespace rime {
1516

17+
inline static Grammar* create_grammar(Config* config) {
18+
if (auto* grammar = Grammar::Require("grammar")) {
19+
return grammar->Create(config);
20+
}
21+
return nullptr;
22+
}
23+
24+
Poet::Poet(const Language* language, Config* config)
25+
: language_(language),
26+
grammar_(create_grammar(config)) {}
27+
28+
Poet::~Poet() {}
29+
1630
an<Sentence> Poet::MakeSentence(const WordGraph& graph,
17-
size_t total_length) {
18-
const int kMaxHomophonesInMind = 1;
31+
size_t total_length) {
32+
// TODO: save more intermediate sentence candidates
1933
map<int, an<Sentence>> sentences;
2034
sentences[0] = New<Sentence>(language_);
2135
// dynamic programming
@@ -30,15 +44,16 @@ an<Sentence> Poet::MakeSentence(const WordGraph& graph,
3044
continue; // exclude single words from the result
3145
DLOG(INFO) << "end pos: " << end_pos;
3246
const DictEntryList& entries(x.second);
33-
for (size_t i = 0; i < kMaxHomophonesInMind && i < entries.size(); ++i) {
47+
for (size_t i = 0; i < entries.size(); ++i) {
3448
const auto& entry(entries[i]);
3549
auto new_sentence = New<Sentence>(*sentences[start_pos]);
36-
new_sentence->Extend(*entry, end_pos);
50+
bool is_rear = end_pos == total_length;
51+
new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get());
3752
if (sentences.find(end_pos) == sentences.end() ||
3853
sentences[end_pos]->weight() < new_sentence->weight()) {
3954
DLOG(INFO) << "updated sentences " << end_pos << ") with '"
4055
<< new_sentence->text() << "', " << new_sentence->weight();
41-
sentences[end_pos] = new_sentence;
56+
sentences[end_pos] = std::move(new_sentence);
4257
}
4358
}
4459
}

src/rime/gear/poet.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,27 @@
1010
#ifndef RIME_POET_H_
1111
#define RIME_POET_H_
1212

13+
#include <rime/common.h>
1314
#include <rime/dict/user_dictionary.h>
1415
#include <rime/gear/translator_commons.h>
1516

1617
namespace rime {
1718

1819
using WordGraph = map<int, UserDictEntryCollector>;
1920

21+
class Grammar;
2022
class Language;
2123

2224
class Poet {
2325
public:
24-
Poet(const Language* language) : language_(language) {}
26+
Poet(const Language* language, Config* config);
27+
~Poet();
2528

2629
an<Sentence> MakeSentence(const WordGraph& graph, size_t total_length);
2730

2831
protected:
2932
const Language* language_;
33+
the<Grammar> grammar_;
3034
};
3135

3236
} // namespace rime

src/rime/gear/script_translator.cc

+11-4
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,11 @@ class ScriptTranslation : public Translation {
103103
public:
104104
ScriptTranslation(ScriptTranslator* translator,
105105
Corrector* corrector,
106+
Poet* poet,
106107
const string& input,
107108
size_t start)
108109
: translator_(translator),
110+
poet_(poet),
109111
start_(start),
110112
syllabifier_(New<ScriptSyllabifier>(
111113
translator, corrector, input, start)),
@@ -124,6 +126,7 @@ class ScriptTranslation : public Translation {
124126
void PrepareCandidate();
125127

126128
ScriptTranslator* translator_;
129+
Poet* poet_;
127130
size_t start_;
128131
an<ScriptSyllabifier> syllabifier_;
129132

@@ -156,6 +159,8 @@ ScriptTranslator::ScriptTranslator(const Ticket& ticket)
156159
config->GetBool(name_space_ + "/always_show_comments",
157160
&always_show_comments_);
158161
config->GetBool(name_space_ + "/enable_correction", &enable_correction_);
162+
config->GetInt(name_space_ + "/max_homophones", &max_homophones_);
163+
poet_.reset(new Poet(language(), config));
159164
}
160165
if (enable_correction_) {
161166
if (auto* corrector = Corrector::Require("corrector")) {
@@ -181,6 +186,7 @@ an<Translation> ScriptTranslator::Query(const string& input,
181186
// the translator should survive translations it creates
182187
auto result = New<ScriptTranslation>(this,
183188
corrector_.get(),
189+
poet_.get(),
184190
input,
185191
segment.start);
186192
if (!result ||
@@ -523,15 +529,16 @@ an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
523529
// merge lookup results
524530
for (auto& y : *phrase) {
525531
DictEntryList& entries(dest[y.first]);
526-
if (entries.empty()) {
532+
while (entries.size() < translator_->max_homophones() &&
533+
!y.second.exhausted()) {
527534
entries.push_back(y.second.Peek());
535+
if (!y.second.Next())
536+
break;
528537
}
529538
}
530539
}
531540
}
532-
Poet poet(translator_->language());
533-
auto sentence = poet.MakeSentence(graph,
534-
syllable_graph.interpreted_length);
541+
auto sentence = poet_->MakeSentence(graph, syllable_graph.interpreted_length);
535542
if (sentence) {
536543
sentence->Offset(start_);
537544
sentence->set_syllabifier(syllabifier_);

src/rime/gear/script_translator.h

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class Corrector;
2121
struct DictEntry;
2222
struct DictEntryCollector;
2323
class Dictionary;
24+
class Poet;
2425
class UserDictionary;
2526
struct SyllableGraph;
2627

@@ -38,14 +39,17 @@ class ScriptTranslator : public Translator,
3839
string Spell(const Code& code);
3940

4041
// options
42+
int max_homophones() const { return max_homophones_; }
4143
int spelling_hints() const { return spelling_hints_; }
4244
bool always_show_comments() const { return always_show_comments_; }
4345

4446
protected:
47+
int max_homophones_ = 1;
4548
int spelling_hints_ = 0;
4649
bool always_show_comments_ = false;
4750
bool enable_correction_ = false;
4851
the<Corrector> corrector_;
52+
the<Poet> poet_;
4953
};
5054

5155
} // namespace rime

src/rime/gear/table_translator.cc

+64-17
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <rime/dict/dictionary.h>
1818
#include <rime/dict/user_dictionary.h>
1919
#include <rime/gear/charset_filter.h>
20+
#include <rime/gear/grammar.h>
2021
#include <rime/gear/table_translator.h>
2122
#include <rime/gear/translator_commons.h>
2223
#include <rime/gear/unity_table_encoder.h>
@@ -216,6 +217,13 @@ TableTranslator::TableTranslator(const Ticket& ticket)
216217
&encode_commit_history_);
217218
config->GetInt(name_space_ + "/max_phrase_length",
218219
&max_phrase_length_);
220+
config->GetInt(name_space_ + "/max_homographs",
221+
&max_homographs_);
222+
if (enable_sentence_ || sentence_over_completion_) {
223+
if (auto* grammar_component = Grammar::Require("grammar")) {
224+
grammar_.reset(grammar_component->Create(config));
225+
}
226+
}
219227
}
220228
if (enable_encoder_ && user_dict_) {
221229
encoder_.reset(new UnityTableEncoder(user_dict_.get()));
@@ -231,7 +239,7 @@ static bool starts_with_completion(an<Translation> translation) {
231239
}
232240

233241
an<Translation> TableTranslator::Query(const string& input,
234-
const Segment& segment) {
242+
const Segment& segment) {
235243
if (!segment.HasTag(tag_))
236244
return nullptr;
237245
DLOG(INFO) << "input = '" << input
@@ -519,7 +527,7 @@ bool SentenceTranslation::PreferUserPhrase() const {
519527
return false;
520528
}
521529

522-
static size_t consume_trailing_delimiters(size_t pos,
530+
inline static size_t consume_trailing_delimiters(size_t pos,
523531
const string& input,
524532
const string& delimiters) {
525533
while (pos < input.length() &&
@@ -529,11 +537,25 @@ static size_t consume_trailing_delimiters(size_t pos,
529537
return pos;
530538
}
531539

540+
template <class Iter>
541+
inline static void collect_entries(DictEntryList& dest,
542+
Iter& iter,
543+
int max_entries) {
544+
if (dest.size() < max_entries && !iter.exhausted()) {
545+
dest.push_back(iter.Peek());
546+
// alters iter if collecting more than 1 entries
547+
while (dest.size() < max_entries && iter.Next()) {
548+
dest.push_back(iter.Peek());
549+
}
550+
}
551+
}
552+
532553
an<Translation>
533554
TableTranslator::MakeSentence(const string& input, size_t start,
534555
bool include_prefix_phrases) {
535556
bool filter_by_charset = enable_charset_filter_ &&
536557
!engine_->context()->get_option("extended_charset");
558+
const int max_entries = max_homographs_;
537559
DictEntryCollector collector;
538560
UserDictEntryCollector user_phrase_collector;
539561
map<int, an<Sentence>> sentences;
@@ -543,13 +565,14 @@ TableTranslator::MakeSentence(const string& input, size_t start,
543565
continue;
544566
string active_input = input.substr(start_pos);
545567
string active_key = active_input + ' ';
546-
vector<of<DictEntry>> entries(active_input.length() + 1);
568+
UserDictEntryCollector collected_entries;
547569
// lookup dictionaries
548570
if (user_dict_ && user_dict_->loaded()) {
549571
for (size_t len = 1; len <= active_input.length(); ++len) {
550572
size_t consumed_length =
551573
consume_trailing_delimiters(len, active_input, delimiters_);
552-
if (entries[consumed_length])
574+
auto& dest(collected_entries[consumed_length]);
575+
if (dest.size() >= max_entries)
553576
continue;
554577
DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")";
555578
UserDictEntryIterator uter;
@@ -560,9 +583,15 @@ TableTranslator::MakeSentence(const string& input, size_t start,
560583
uter.AddFilter(CharsetFilter::FilterDictEntry);
561584
}
562585
if (!uter.exhausted()) {
563-
entries[consumed_length] = uter.Peek();
586+
if (start_pos == 0 && max_entries > 1) {
587+
UserDictEntryIterator uter_copy(uter);
588+
collect_entries(dest, uter_copy, max_entries);
589+
} else {
590+
collect_entries(dest, uter, max_entries);
591+
}
564592
if (start_pos == 0) {
565593
// also provide words for manual composition
594+
// uter must not be consumed
566595
uter.Release(&user_phrase_collector[consumed_length]);
567596
DLOG(INFO) << "user phrase[" << consumed_length << "]: "
568597
<< user_phrase_collector[consumed_length].size();
@@ -578,7 +607,8 @@ TableTranslator::MakeSentence(const string& input, size_t start,
578607
for (size_t len = 1; len <= active_input.length(); ++len) {
579608
size_t consumed_length =
580609
consume_trailing_delimiters(len, active_input, delimiters_);
581-
if (entries[consumed_length])
610+
auto& dest(collected_entries[consumed_length]);
611+
if (!dest.empty())
582612
continue;
583613
DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")";
584614
UserDictEntryIterator uter;
@@ -589,9 +619,15 @@ TableTranslator::MakeSentence(const string& input, size_t start,
589619
uter.AddFilter(CharsetFilter::FilterDictEntry);
590620
}
591621
if (!uter.exhausted()) {
592-
entries[consumed_length] = uter.Peek();
622+
if (start_pos == 0 && max_entries > 1) {
623+
UserDictEntryIterator uter_copy(uter);
624+
collect_entries(dest, uter_copy, max_entries);
625+
} else {
626+
collect_entries(dest, uter, max_entries);
627+
}
593628
if (start_pos == 0) {
594629
// also provide words for manual composition
630+
// uter must not be consumed
595631
uter.Release(&user_phrase_collector[consumed_length]);
596632
DLOG(INFO) << "unity phrase[" << consumed_length << "]: "
597633
<< user_phrase_collector[consumed_length].size();
@@ -612,17 +648,24 @@ TableTranslator::MakeSentence(const string& input, size_t start,
612648
continue;
613649
size_t consumed_length =
614650
consume_trailing_delimiters(m.length, active_input, delimiters_);
615-
if (entries[consumed_length])
651+
auto& dest(collected_entries[consumed_length]);
652+
if (dest.size() >= max_entries)
616653
continue;
617654
DictEntryIterator iter;
618655
dict_->LookupWords(&iter, active_input.substr(0, m.length), false);
619656
if (filter_by_charset) {
620657
iter.AddFilter(CharsetFilter::FilterDictEntry);
621658
}
622659
if (!iter.exhausted()) {
623-
entries[consumed_length] = iter.Peek();
660+
if (start_pos == 0 && max_entries - dest.size() > 1) {
661+
DictEntryIterator iter_copy = iter;
662+
collect_entries(dest, iter_copy, max_entries);
663+
} else {
664+
collect_entries(dest, iter, max_entries);
665+
}
624666
if (start_pos == 0) {
625667
// also provide words for manual composition
668+
// iter must not be consumed
626669
collector[consumed_length] = std::move(iter);
627670
DLOG(INFO) << "table[" << consumed_length << "]: "
628671
<< collector[consumed_length].entry_count();
@@ -631,16 +674,20 @@ TableTranslator::MakeSentence(const string& input, size_t start,
631674
}
632675
}
633676
for (size_t len = 1; len <= active_input.length(); ++len) {
634-
if (!entries[len])
677+
const auto& entries(collected_entries[len]);
678+
if (entries.empty())
635679
continue;
636680
size_t end_pos = start_pos + len;
637-
// create a new sentence
638-
auto new_sentence = New<Sentence>(*sentences[start_pos]);
639-
new_sentence->Extend(*entries[len], end_pos);
640-
// compare and update sentences
641-
if (sentences.find(end_pos) == sentences.end() ||
642-
sentences[end_pos]->weight() <= new_sentence->weight()) {
643-
sentences[end_pos] = std::move(new_sentence);
681+
bool is_rear = end_pos == input.length();
682+
for (const auto& entry : entries) {
683+
// create a new sentence
684+
auto new_sentence = New<Sentence>(*sentences[start_pos]);
685+
new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get());
686+
// compare and update sentences
687+
if (sentences.find(end_pos) == sentences.end() ||
688+
sentences[end_pos]->weight() <= new_sentence->weight()) {
689+
sentences[end_pos] = std::move(new_sentence);
690+
}
644691
}
645692
}
646693
}

0 commit comments

Comments
 (0)