@@ -41,43 +41,81 @@ bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) {
4141 other.syllable_lengths ().end ()))));
4242}
4343
44+ // keep the best sentence candidate per last phrase
45+ using SentenceCandidates = hash_map<string, of<Sentence>>;
46+
47+ static vector<of<Sentence>> top_candidates (const SentenceCandidates& candidates,
48+ size_t n,
49+ Poet::Compare& compare) {
50+ vector<of<Sentence>> top;
51+ top.reserve (n + 1 );
52+ for (const auto & candidate : candidates) {
53+ auto pos = std::upper_bound (
54+ top.begin (), top.end (), candidate.second ,
55+ [&](const an<Sentence>& a, const an<Sentence>& b) {
56+ return !compare (*a, *b); // desc
57+ });
58+ if (pos - top.begin () >= n) continue ;
59+ top.insert (pos, candidate.second );
60+ if (top.size () > n) top.pop_back ();
61+ }
62+ return top;
63+ }
64+
65+ an<Sentence> find_best_sentence (const SentenceCandidates& candidates,
66+ Poet::Compare& compare) {
67+ an<Sentence> best = nullptr ;
68+ for (const auto & candidate : candidates) {
69+ if (!best || compare (*best, *candidate.second )) {
70+ best = candidate.second ;
71+ }
72+ }
73+ return best;
74+ }
75+
76+ constexpr int kMaxSentenceCandidates = 7 ;
77+
4478an<Sentence> Poet::MakeSentence (const WordGraph& graph,
4579 size_t total_length,
4680 const string& preceding_text) {
47- // TODO: save more intermediate sentence candidates
48- map<int , an<Sentence>> sentences;
49- sentences[0 ] = New<Sentence>(language_);
50- // dynamic programming
81+ map<int , SentenceCandidates> sentences;
82+ sentences[0 ].emplace (" " , New<Sentence>(language_));
5183 for (const auto & w : graph) {
5284 size_t start_pos = w.first ;
53- DLOG (INFO) << " start pos: " << start_pos;
5485 if (sentences.find (start_pos) == sentences.end ())
5586 continue ;
56- for (const auto & x : w.second ) {
57- size_t end_pos = x.first ;
58- if (start_pos == 0 && end_pos == total_length)
59- continue ; // exclude single words from the result
60- DLOG (INFO) << " end pos: " << end_pos;
61- bool is_rear = end_pos == total_length;
62- const DictEntryList& entries (x.second );
63- for (const auto & entry : entries) {
64- auto new_sentence = New<Sentence>(*sentences[start_pos]);
65- new_sentence->Extend (
66- *entry, end_pos, is_rear, preceding_text, grammar_.get ());
67- if (sentences.find (end_pos) == sentences.end () ||
68- compare_ (*sentences[end_pos], *new_sentence)) {
69- DLOG (INFO) << " updated sentences " << end_pos << " ) with "
70- << new_sentence->text () << " weight: "
71- << new_sentence->weight ();
72- sentences[end_pos] = std::move (new_sentence);
87+ DLOG (INFO) << " start pos: " << start_pos;
88+ auto top = top_candidates (
89+ sentences[start_pos], kMaxSentenceCandidates , compare_);
90+ for (const auto & candidate : top) {
91+ for (const auto & x : w.second ) {
92+ size_t end_pos = x.first ;
93+ if (start_pos == 0 && end_pos == total_length)
94+ continue ; // exclude single words from the result
95+ DLOG (INFO) << " end pos: " << end_pos;
96+ bool is_rear = end_pos == total_length;
97+ auto & target (sentences[end_pos]);
98+ const DictEntryList& entries (x.second );
99+ for (const auto & entry : entries) {
100+ auto new_sentence = New<Sentence>(*candidate);
101+ new_sentence->Extend (
102+ *entry, end_pos, is_rear, preceding_text, grammar_.get ());
103+ const auto & key = new_sentence->components ().back ().text ;
104+ auto & best_sentence = target[key];
105+ if (!best_sentence || compare_ (*best_sentence, *new_sentence)) {
106+ DLOG (INFO) << " updated sentences " << end_pos << " ) with "
107+ << new_sentence->text () << " weight: "
108+ << new_sentence->weight ();
109+ best_sentence = std::move (new_sentence);
110+ }
73111 }
74112 }
75113 }
76114 }
77115 if (sentences.find (total_length) == sentences.end ())
78116 return nullptr ;
79117 else
80- return sentences[total_length];
118+ return find_best_sentence ( sentences[total_length], compare_) ;
81119}
82120
83121} // namespace rime
0 commit comments