@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }
 
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
+    const std::string& delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    // first token is always a newline, as it was not previously added
+    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(vocab, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(vocab, current_word, false, true);
+    if (tmp.size() > 0) {
+        result.push_back(tmp[0]);
+    }
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
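
For intuition, here is a minimal standalone sketch of the splitting logic above (an illustration, not part of the patch: split_on_text_sep is a made-up name, and plain string fragments stand in for the first-token ids that prepare_guide_tokens actually collects via common_tokenize):

#include <iostream>
#include <string>
#include <vector>

// Illustration only: split an OuteTTS-formatted prompt on "<|text_sep|>",
// mirroring the loop in prepare_guide_tokens but keeping the raw fragments
// instead of tokenizing them.
static std::vector<std::string> split_on_text_sep(const std::string & str) {
    const std::string delimiter = "<|text_sep|>";
    std::vector<std::string> result;
    size_t start = 0;
    size_t end   = str.find(delimiter);
    while (end != std::string::npos) {
        result.push_back(str.substr(start, end - start));
        start = end + delimiter.length();
        end   = str.find(delimiter, start);
    }
    result.push_back(str.substr(start)); // last fragment has no trailing delimiter
    return result;
}

int main() {
    // prints: 'hello' 'world' 'lovely'
    for (const auto & word : split_on_text_sep("hello<|text_sep|>world<|text_sep|>lovely")) {
        std::cout << "'" << word << "' ";
    }
    std::cout << "\n";
}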
@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;
 
     // process prompt and generate voice codes
     {
@@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if (params.vocoder.use_guide_tokens) {
+                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
+            }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
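
(For reference: prompt_clean is the OuteTTS-formatted text produced by process_text earlier in this file, roughly the lowercased words joined by <|text_sep|>, e.g. "hello<|text_sep|>world". prepare_guide_tokens therefore yields one guide token per word, after the leading newline token.)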
@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     int n_past   = batch.n_tokens;
     int n_decode = 0;
 
+    bool next_token_uses_guide_token = true;
+
     while (n_decode <= n_predict) {
         // prepare the next batch
         common_batch_clear(batch);
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                 continue;
             }
 
-            const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+            llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+            // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+            if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
+                llama_token guide_token = guide_tokens[0];
+                guide_tokens.erase(guide_tokens.begin());
+                new_token_id = guide_token; // ensure correct word fragment is used
+            }
+
+            // this is the token id that always precedes a new word
+            next_token_uses_guide_token = (new_token_id == 198);
 
             common_sampler_accept(smpl[i], new_token_id, true);
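
Taken together, the decode-loop changes implement a simple word-level forcing scheme. A condensed sketch of the rule as a free function (a paraphrase for readability, not part of the diff: force_guide_token is a made-up name, it assumes the llama.h context of this file, and the hardcoded 198 is taken to be the vocabulary's newline token id, i.e. the same token prepare_guide_tokens pushes first):

// Paraphrase of the patch's substitution rule, not part of the diff.
llama_token force_guide_token(llama_token sampled,
                              std::vector<llama_token> & guide_tokens,
                              bool & next_uses_guide,
                              const llama_vocab * vocab) {
    // only override ordinary text tokens, never control or end-of-generation tokens
    if (next_uses_guide && !guide_tokens.empty() &&
        !llama_vocab_is_control(vocab, sampled) &&
        !llama_vocab_is_eog(vocab, sampled)) {
        sampled = guide_tokens.front();           // force the word's first token
        guide_tokens.erase(guide_tokens.begin());
    }
    // a newline marks a word boundary, so the next sampled token
    // starts a new word and is the next candidate for substitution
    next_uses_guide = (sampled == 198);
    return sampled;
}

Note that only the first token of each word is forced; the model still generates the rest of the word and its audio codes freely, which is enough to keep the output anchored to the prompt text.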