Skip to content

Commit b3f4005

Browse files
committed
feat(poet): find best sentence candidates
1 parent 9934788 commit b3f4005

File tree

1 file changed

+61
-23
lines changed

1 file changed

+61
-23
lines changed

src/rime/gear/poet.cc

+61-23
Original file line numberDiff line numberDiff line change
@@ -41,43 +41,81 @@ bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) {
4141
other.syllable_lengths().end()))));
4242
}
4343

44+
// keep the best sentence candidate per last phrase
45+
using SentenceCandidates = hash_map<string, of<Sentence>>;
46+
47+
static vector<of<Sentence>> top_candidates(const SentenceCandidates& candidates,
48+
size_t n,
49+
Poet::Compare& compare) {
50+
vector<of<Sentence>> top;
51+
top.reserve(n + 1);
52+
for (const auto& candidate : candidates) {
53+
auto pos = std::upper_bound(
54+
top.begin(), top.end(), candidate.second,
55+
[&](const an<Sentence>& a, const an<Sentence>& b) {
56+
return !compare(*a, *b); // desc
57+
});
58+
if (pos - top.begin() >= n) continue;
59+
top.insert(pos, candidate.second);
60+
if (top.size() > n) top.pop_back();
61+
}
62+
return top;
63+
}
64+
65+
an<Sentence> find_best_sentence(const SentenceCandidates& candidates,
66+
Poet::Compare& compare) {
67+
an<Sentence> best = nullptr;
68+
for (const auto& candidate : candidates) {
69+
if (!best || compare(*best, *candidate.second)) {
70+
best = candidate.second;
71+
}
72+
}
73+
return best;
74+
}
75+
76+
constexpr int kMaxSentenceCandidates = 7;
77+
4478
an<Sentence> Poet::MakeSentence(const WordGraph& graph,
4579
size_t total_length,
4680
const string& preceding_text) {
47-
// TODO: save more intermediate sentence candidates
48-
map<int, an<Sentence>> sentences;
49-
sentences[0] = New<Sentence>(language_);
50-
// dynamic programming
81+
map<int, SentenceCandidates> sentences;
82+
sentences[0].emplace("", New<Sentence>(language_));
5183
for (const auto& w : graph) {
5284
size_t start_pos = w.first;
53-
DLOG(INFO) << "start pos: " << start_pos;
5485
if (sentences.find(start_pos) == sentences.end())
5586
continue;
56-
for (const auto& x : w.second) {
57-
size_t end_pos = x.first;
58-
if (start_pos == 0 && end_pos == total_length)
59-
continue; // exclude single words from the result
60-
DLOG(INFO) << "end pos: " << end_pos;
61-
bool is_rear = end_pos == total_length;
62-
const DictEntryList& entries(x.second);
63-
for (const auto& entry : entries) {
64-
auto new_sentence = New<Sentence>(*sentences[start_pos]);
65-
new_sentence->Extend(
66-
*entry, end_pos, is_rear, preceding_text, grammar_.get());
67-
if (sentences.find(end_pos) == sentences.end() ||
68-
compare_(*sentences[end_pos], *new_sentence)) {
69-
DLOG(INFO) << "updated sentences " << end_pos << ") with "
70-
<< new_sentence->text() << " weight: "
71-
<< new_sentence->weight();
72-
sentences[end_pos] = std::move(new_sentence);
87+
DLOG(INFO) << "start pos: " << start_pos;
88+
auto top = top_candidates(
89+
sentences[start_pos], kMaxSentenceCandidates, compare_);
90+
for (const auto& candidate : top) {
91+
for (const auto& x : w.second) {
92+
size_t end_pos = x.first;
93+
if (start_pos == 0 && end_pos == total_length)
94+
continue; // exclude single words from the result
95+
DLOG(INFO) << "end pos: " << end_pos;
96+
bool is_rear = end_pos == total_length;
97+
auto& target(sentences[end_pos]);
98+
const DictEntryList& entries(x.second);
99+
for (const auto& entry : entries) {
100+
auto new_sentence = New<Sentence>(*candidate);
101+
new_sentence->Extend(
102+
*entry, end_pos, is_rear, preceding_text, grammar_.get());
103+
const auto& key = new_sentence->components().back().text;
104+
auto& best_sentence = target[key];
105+
if (!best_sentence || compare_(*best_sentence, *new_sentence)) {
106+
DLOG(INFO) << "updated sentences " << end_pos << ") with "
107+
<< new_sentence->text() << " weight: "
108+
<< new_sentence->weight();
109+
best_sentence = std::move(new_sentence);
110+
}
73111
}
74112
}
75113
}
76114
}
77115
if (sentences.find(total_length) == sentences.end())
78116
return nullptr;
79117
else
80-
return sentences[total_length];
118+
return find_best_sentence(sentences[total_length], compare_);
81119
}
82120

83121
} // namespace rime

0 commit comments

Comments
 (0)