Skip to content

Commit 7047f65

Browse files
LostRuinstinglou
authored and committed
tts : add guide tokens support (ggml-org#11186)
* Added the ability to use guide tokens for OuteTTS, greatly improving TTS recitation accuracy over long input sequences.
* Applied linting suggestions, updated to latest llama_vocab changes, added a safety check, and added a newline to the guide token start.
1 parent 60fb2da commit 7047f65

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

common/arg.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -2254,6 +2254,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
22542254
params.vocoder.model = value;
22552255
}
22562256
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
2257+
add_opt(common_arg(
2258+
{"--tts-use-guide-tokens"},
2259+
"Use guide tokens to improve TTS word recall",
2260+
[](common_params & params) {
2261+
params.vocoder.use_guide_tokens = true;
2262+
}
2263+
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
22572264

22582265
// model-specific
22592266
add_opt(common_arg(

common/common.h

+2
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ struct common_params_vocoder {
184184

185185
std::string model = ""; // model path // NOLINT
186186
std::string model_url = ""; // model url to download // NOLINT
187+
188+
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
187189
};
188190

189191
struct common_params {

examples/tts/tts.cpp

+44-1
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
425425
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
426426
}
427427

428+
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
429+
const std::string& delimiter = "<|text_sep|>";
430+
431+
std::vector<llama_token> result;
432+
size_t start = 0;
433+
size_t end = str.find(delimiter);
434+
435+
//first token is always a newline, as it was not previously added
436+
result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
437+
438+
while (end != std::string::npos) {
439+
std::string current_word = str.substr(start, end - start);
440+
auto tmp = common_tokenize(vocab, current_word, false, true);
441+
result.push_back(tmp[0]);
442+
start = end + delimiter.length();
443+
end = str.find(delimiter, start);
444+
}
445+
446+
// Add the last part
447+
std::string current_word = str.substr(start);
448+
auto tmp = common_tokenize(vocab, current_word, false, true);
449+
if (tmp.size() > 0) {
450+
result.push_back(tmp[0]);
451+
}
452+
return result;
453+
}
454+
428455
int main(int argc, char ** argv) {
429456
common_params params;
430457

@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
494521
const auto t_main_start = ggml_time_us();
495522

496523
std::vector<llama_token> codes;
524+
std::vector<llama_token> guide_tokens;
497525

498526
// process prompt and generate voice codes
499527
{
@@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
508536
// convert the input text into the necessary format expected by OuteTTS
509537
{
510538
std::string prompt_clean = process_text(params.prompt);
539+
if (params.vocoder.use_guide_tokens) {
540+
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
541+
}
511542

512543
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
513544

@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
717748
int n_past = batch.n_tokens;
718749
int n_decode = 0;
719750

751+
bool next_token_uses_guide_token = true;
752+
720753
while (n_decode <= n_predict) {
721754
// prepare the next batch
722755
common_batch_clear(batch);
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
728761
continue;
729762
}
730763

731-
const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
764+
llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
765+
766+
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
767+
if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
768+
llama_token guide_token = guide_tokens[0];
769+
guide_tokens.erase(guide_tokens.begin());
770+
new_token_id = guide_token; //ensure correct word fragment is used
771+
}
772+
773+
//this is the token id that always precedes a new word
774+
next_token_uses_guide_token = (new_token_id == 198);
732775

733776
common_sampler_accept(smpl[i], new_token_id, true);
734777

0 commit comments

Comments
 (0)