Skip to content

Commit 0853465

Browse files
committed
perf(poet): optimize for performance in making sentences (~40% faster)
1 parent 44dd002 commit 0853465

7 files changed

+183
-121
lines changed

src/rime/gear/contextual_translation.cc

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <algorithm>
22
#include <iterator>
33
#include <rime/gear/contextual_translation.h>
4+
#include <rime/gear/grammar.h>
45
#include <rime/gear/translator_commons.h>
56

67
namespace rime {
@@ -37,12 +38,13 @@ bool ContextualTranslation::Replenish() {
3738
}
3839

3940
an<Phrase> ContextualTranslation::Evaluate(an<Phrase> phrase) {
40-
auto sentence = New<Sentence>(phrase->language());
41-
sentence->Offset(phrase->start());
4241
bool is_rear = phrase->end() == input_.length();
43-
sentence->Extend(phrase->entry(), phrase->end(), is_rear, preceding_text_,
44-
grammar_);
45-
phrase->set_weight(sentence->weight());
42+
double weight = Grammar::Evaluate(preceding_text_,
43+
phrase->text(),
44+
phrase->weight(),
45+
is_rear,
46+
grammar_);
47+
phrase->set_weight(weight);
4648
DLOG(INFO) << "contextual suggestion: " << phrase->text()
4749
<< " weight: " << phrase->weight();
4850
return phrase;

src/rime/gear/grammar.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include <rime/common.h>
55
#include <rime/component.h>
6-
#include <rime/dict/vocabulary.h>
76

87
namespace rime {
98

@@ -17,12 +16,13 @@ class Grammar : public Class<Grammar, Config*> {
1716
bool is_rear) = 0;
1817

1918
inline static double Evaluate(const string& context,
20-
const DictEntry& entry,
19+
const string& entry_text,
20+
double entry_weight,
2121
bool is_rear,
2222
Grammar* grammar) {
2323
const double kPenalty = -18.420680743952367; // log(1e-8)
24-
return entry.weight +
25-
(grammar ? grammar->Query(context, entry.text, is_rear) : kPenalty);
24+
return entry_weight +
25+
(grammar ? grammar->Query(context, entry_text, is_rear) : kPenalty);
2626
}
2727
};
2828

src/rime/gear/poet.cc

+152-82
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,64 @@
1616

1717
namespace rime {
1818

19+
// internal data structure used during the sentence making process.
20+
// the output line of the algorithm is transformed to an<Sentence>.
21+
struct Line {
22+
// be sure the pointer to predecessor Line object is stable. it works since
23+
// pointer to values stored in std::map and std::unordered_map are stable.
24+
const Line* predecessor;
25+
// as long as the word graph lives, pointers to entries are valid.
26+
const DictEntry* entry;
27+
size_t end_pos;
28+
double weight;
29+
30+
static const Line kEmpty;
31+
32+
bool empty() const {
33+
return !predecessor && !entry;
34+
}
35+
36+
string last_word() const {
37+
return entry ? entry->text : string();
38+
}
39+
40+
struct Components {
41+
vector<const Line*> lines;
42+
43+
Components(const Line* line) {
44+
for (const Line* cursor = line;
45+
!cursor->empty();
46+
cursor = cursor->predecessor) {
47+
lines.push_back(cursor);
48+
}
49+
}
50+
51+
decltype(lines.crbegin()) begin() const { return lines.crbegin(); }
52+
decltype(lines.crend()) end() const { return lines.crend(); }
53+
};
54+
55+
Components components() const { return Components(this); }
56+
57+
string context() const {
58+
// look back 2 words
59+
return empty() ? string() :
60+
!predecessor || predecessor->empty() ? last_word() :
61+
predecessor->last_word() + last_word();
62+
}
63+
64+
vector<size_t> word_lengths() const {
65+
vector<size_t> lengths;
66+
size_t last_end_pos = 0;
67+
for (const auto* c : components()) {
68+
lengths.push_back(c->end_pos - last_end_pos);
69+
last_end_pos = c->end_pos;
70+
}
71+
return lengths;
72+
}
73+
};
74+
75+
const Line Line::kEmpty{nullptr, nullptr, 0, 0.0};
76+
1977
inline static Grammar* create_grammar(Config* config) {
2078
if (auto* grammar = Grammar::Require("grammar")) {
2179
return grammar->Create(config);
@@ -30,102 +88,103 @@ Poet::Poet(const Language* language, Config* config, Compare compare)
3088

3189
Poet::~Poet() {}
3290

33-
bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) {
34-
return one.weight() < other.weight() || ( // left associate if even
35-
one.weight() == other.weight() && (
36-
one.size() > other.size() || ( // less components is more favorable
37-
one.size() == other.size() &&
38-
std::lexicographical_compare(one.syllable_lengths().begin(),
39-
one.syllable_lengths().end(),
40-
other.syllable_lengths().begin(),
41-
other.syllable_lengths().end()))));
91+
bool Poet::CompareWeight(const Line& one, const Line& other) {
92+
return one.weight < other.weight;
93+
}
94+
95+
// returns true if one is less than other.
96+
bool Poet::LeftAssociateCompare(const Line& one, const Line& other) {
97+
if (one.weight < other.weight) return true;
98+
if (one.weight == other.weight) {
99+
auto one_word_lens = one.word_lengths();
100+
auto other_word_lens = other.word_lengths();
101+
// less words is more favorable
102+
if (one_word_lens.size() > other_word_lens.size()) return true;
103+
if (one_word_lens.size() == other_word_lens.size()) {
104+
return std::lexicographical_compare(
105+
one_word_lens.begin(), one_word_lens.end(),
106+
other_word_lens.begin(), other_word_lens.end());
107+
}
108+
}
109+
return false;
42110
}
43111

44-
// keep the best sentence candidate per last phrase
45-
using SentenceCandidates = hash_map<string, of<Sentence>>;
112+
// keep the best line candidate per last phrase
113+
using LineCandidates = hash_map<string, Line>;
46114

47115
template <int N>
48-
static vector<of<Sentence>> find_top_candidates(
49-
const SentenceCandidates& candidates, Poet::Compare compare) {
50-
vector<of<Sentence>> top;
116+
static vector<const Line*> find_top_candidates(
117+
const LineCandidates& candidates, Poet::Compare compare) {
118+
vector<const Line*> top;
51119
top.reserve(N + 1);
52120
for (const auto& candidate : candidates) {
53121
auto pos = std::upper_bound(
54-
top.begin(), top.end(), candidate.second,
55-
[&](const an<Sentence>& a, const an<Sentence>& b) {
56-
return !compare(*a, *b); // desc
57-
});
122+
top.begin(), top.end(), &candidate.second,
123+
[&](const Line* a, const Line* b) { return compare(*b, *a); }); // desc
58124
if (pos - top.begin() >= N) continue;
59-
top.insert(pos, candidate.second);
125+
top.insert(pos, &candidate.second);
60126
if (top.size() > N) top.pop_back();
61127
}
62128
return top;
63129
}
64130

65-
static an<Sentence> find_best_sentence(const SentenceCandidates& candidates,
66-
Poet::Compare compare) {
67-
an<Sentence> best = nullptr;
68-
for (const auto& candidate : candidates) {
69-
if (!best || compare(*best, *candidate.second)) {
70-
best = candidate.second;
71-
}
72-
}
73-
return best;
74-
}
75-
76-
using UpdateSetenceCandidate = function<void (const an<Sentence>& candidate)>;
131+
using UpdateLineCandidate = function<void (const Line& candidate)>;
77132

78133
struct BeamSearch {
79-
using State = SentenceCandidates;
134+
using State = LineCandidates;
80135

81-
static constexpr int kMaxSentenceCandidates = 7;
136+
static constexpr int kMaxLineCandidates = 7;
82137

83-
static void Initiate(State& initial_state, const Language* language) {
84-
initial_state.emplace("", New<Sentence>(language));
138+
static void Initiate(State& initial_state) {
139+
initial_state.emplace("", Line::kEmpty);
85140
}
86141

87142
static void ForEachCandidate(const State& state,
88143
Poet::Compare compare,
89-
UpdateSetenceCandidate update) {
144+
UpdateLineCandidate update) {
90145
auto top_candidates =
91-
find_top_candidates<kMaxSentenceCandidates>(state, compare);
92-
for (const auto& candidate : top_candidates) {
93-
update(candidate);
146+
find_top_candidates<kMaxLineCandidates>(state, compare);
147+
for (const auto* candidate : top_candidates) {
148+
update(*candidate);
94149
}
95150
}
96151

97-
static an<Sentence>& BestSentenceToUpdate(State& state,
98-
const an<Sentence>& new_sentence) {
99-
const auto& key = new_sentence->components().back().text;
152+
static Line& BestLineToUpdate(State& state, const Line& new_line) {
153+
const auto& key = new_line.last_word();
100154
return state[key];
101155
}
102156

103-
static an<Sentence> BestSentence(const State& final_state,
104-
Poet::Compare compare) {
105-
return find_best_sentence(final_state, compare);
157+
static const Line& BestLineInState(const State& final_state,
158+
Poet::Compare compare) {
159+
const Line* best = nullptr;
160+
for (const auto& candidate : final_state) {
161+
if (!best || compare(*best, candidate.second)) {
162+
best = &candidate.second;
163+
}
164+
}
165+
return best ? *best : Line::kEmpty;
106166
}
107167
};
108168

109169
struct DynamicProgramming {
110-
using State = an<Sentence>;
170+
using State = Line;
111171

112-
static void Initiate(State& initial_state, const Language* language) {
113-
initial_state = New<Sentence>(language);
172+
static void Initiate(State& initial_state) {
173+
initial_state = Line::kEmpty;
114174
}
115175

116176
static void ForEachCandidate(const State& state,
117177
Poet::Compare compare,
118-
UpdateSetenceCandidate update) {
178+
UpdateLineCandidate update) {
119179
update(state);
120180
}
121181

122-
static an<Sentence>& BestSentenceToUpdate(State& state,
123-
const an<Sentence>& new_sentence) {
182+
static Line& BestLineToUpdate(State& state, const Line& new_line) {
124183
return state;
125184
}
126185

127-
static an<Sentence> BestSentence(const State& final_state,
128-
Poet::Compare compare) {
186+
static const Line& BestLineInState(const State& final_state,
187+
Poet::Compare compare) {
129188
return final_state;
130189
}
131190
};
@@ -134,47 +193,58 @@ template <class Strategy>
134193
an<Sentence> Poet::MakeSentenceWithStrategy(const WordGraph& graph,
135194
size_t total_length,
136195
const string& preceding_text) {
137-
map<int, typename Strategy::State> sentences;
138-
Strategy::Initiate(sentences[0], language_);
139-
for (const auto& w : graph) {
140-
size_t start_pos = w.first;
141-
if (sentences.find(start_pos) == sentences.end())
196+
map<int, typename Strategy::State> states;
197+
Strategy::Initiate(states[0]);
198+
for (const auto& sv : graph) {
199+
size_t start_pos = sv.first;
200+
if (states.find(start_pos) == states.end())
142201
continue;
143202
DLOG(INFO) << "start pos: " << start_pos;
144-
const auto& source(sentences[start_pos]);
145-
Strategy::ForEachCandidate(
146-
source, compare_,
147-
[&](const an<Sentence>& candidate) {
148-
for (const auto& x : w.second) {
149-
size_t end_pos = x.first;
203+
const auto& source_state = states[start_pos];
204+
const auto update =
205+
[this, &states, &sv, start_pos, total_length, &preceding_text]
206+
(const Line& candidate) {
207+
for (const auto& ev : sv.second) {
208+
size_t end_pos = ev.first;
150209
if (start_pos == 0 && end_pos == total_length)
151-
continue; // exclude single words from the result
210+
continue; // exclude single word from the result
152211
DLOG(INFO) << "end pos: " << end_pos;
153212
bool is_rear = end_pos == total_length;
154-
auto& target(sentences[end_pos]);
213+
auto& target_state = states[end_pos];
155214
// extend candidates with dict entries on a valid edge.
156-
const DictEntryList& entries(x.second);
215+
const DictEntryList& entries = ev.second;
157216
for (const auto& entry : entries) {
158-
auto new_sentence = New<Sentence>(*candidate);
159-
new_sentence->Extend(
160-
*entry, end_pos, is_rear, preceding_text, grammar_.get());
161-
auto& best_sentence =
162-
Strategy::BestSentenceToUpdate(target, new_sentence);
163-
if (!best_sentence || compare_(*best_sentence, *new_sentence)) {
164-
DLOG(INFO) << "updated sentences " << end_pos << ") with "
165-
<< new_sentence->text() << " weight: "
166-
<< new_sentence->weight();
167-
best_sentence = std::move(new_sentence);
217+
const string& context =
218+
candidate.empty() ? preceding_text : candidate.context();
219+
double weight = candidate.weight +
220+
Grammar::Evaluate(context,
221+
entry->text,
222+
entry->weight,
223+
is_rear,
224+
grammar_.get());
225+
Line new_line{&candidate, entry.get(), end_pos, weight};
226+
Line& best = Strategy::BestLineToUpdate(target_state, new_line);
227+
if (best.empty() || compare_(best, new_line)) {
228+
DLOG(INFO) << "updated line ending at " << end_pos
229+
<< " with text: ..." << new_line.last_word()
230+
<< " weight: " << new_line.weight;
231+
best = new_line;
168232
}
169233
}
170234
}
171-
});
235+
};
236+
Strategy::ForEachCandidate(source_state, compare_, update);
172237
}
173-
auto found = sentences.find(total_length);
174-
if (found == sentences.end())
238+
auto found = states.find(total_length);
239+
if (found == states.end() || found->second.empty())
175240
return nullptr;
176-
else
177-
return Strategy::BestSentence(found->second, compare_);
241+
const Line& best = Strategy::BestLineInState(found->second, compare_);
242+
auto sentence = New<Sentence>(language_);
243+
for (const auto* c : best.components()) {
244+
if (!c->entry) continue;
245+
sentence->Extend(*c->entry, c->end_pos, c->weight);
246+
}
247+
return sentence;
178248
}
179249

180250
an<Sentence> Poet::MakeSentence(const WordGraph& graph,

src/rime/gear/poet.h

+5-6
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,15 @@ using WordGraph = map<int, UserDictEntryCollector>;
2222

2323
class Grammar;
2424
class Language;
25+
struct Line;
2526

2627
class Poet {
2728
public:
28-
// sentence "less", used to compare sentences of the same input range.
29-
using Compare = function<bool (const Sentence&, const Sentence&)>;
29+
// Line "less", used to compare composed line of the same input range.
30+
using Compare = function<bool (const Line&, const Line&)>;
3031

31-
static bool CompareWeight(const Sentence& one, const Sentence& other) {
32-
return one.weight() < other.weight();
33-
}
34-
static bool LeftAssociateCompare(const Sentence& one, const Sentence& other);
32+
static bool CompareWeight(const Line& one, const Line& other);
33+
static bool LeftAssociateCompare(const Line& one, const Line& other);
3534

3635
Poet(const Language* language, Config* config,
3736
Compare compare = CompareWeight);

0 commit comments

Comments
 (0)