Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ability to use guide tokens for OuteTTS, greatly improving TTS recitation accuracy over long input sequences. #11186

Merged
merged 3 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.vocoder.model = value;
}
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--tts-use-guide-tokens"},
"Use guide tokens to improve TTS word recall",
[](common_params & params) {
params.vocoder.use_guide_tokens = true;
}
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));

// model-specific
add_opt(common_arg(
Expand Down
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ struct common_params_vocoder {

std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT

bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};

struct common_params {
Expand Down
45 changes: 44 additions & 1 deletion examples/tts/tts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}

// Split the cleaned OuteTTS prompt on the "<|text_sep|>" word delimiter and
// collect one guide token per word (only the FIRST token of each word is kept;
// forcing that fragment during sampling is enough to steer the model back onto
// the correct word).
// The returned vector starts with the newline token, since the caller has not
// added it yet and a newline (token id 198) is what precedes each new word.
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
    const std::string delimiter = "<|text_sep|>";

    std::vector<llama_token> result;
    size_t start = 0;
    size_t end = str.find(delimiter);

    // first token is always a newline, as it was not previously added
    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);

    while (end != std::string::npos) {
        std::string current_word = str.substr(start, end - start);
        auto tmp = common_tokenize(vocab, current_word, false, true);
        // guard against empty words (e.g. consecutive delimiters) — indexing
        // tmp[0] on an empty tokenization result is undefined behavior
        if (!tmp.empty()) {
            result.push_back(tmp[0]);
        }
        start = end + delimiter.length();
        end = str.find(delimiter, start);
    }

    // add the last part (text after the final delimiter), with the same guard
    std::string current_word = str.substr(start);
    auto tmp = common_tokenize(vocab, current_word, false, true);
    if (!tmp.empty()) {
        result.push_back(tmp[0]);
    }
    return result;
}

int main(int argc, char ** argv) {
common_params params;

Expand Down Expand Up @@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us();

std::vector<llama_token> codes;
std::vector<llama_token> guide_tokens;

// process prompt and generate voice codes
{
Expand All @@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
}

LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());

Expand Down Expand Up @@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
int n_past = batch.n_tokens;
int n_decode = 0;

bool next_token_uses_guide_token = true;

while (n_decode <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
Expand All @@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
continue;
}

const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);

//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
llama_token guide_token = guide_tokens[0];
guide_tokens.erase(guide_tokens.begin());
new_token_id = guide_token; //ensure correct word fragment is used
}

//this is the token id that always precedes a new word
next_token_uses_guide_token = (new_token_id == 198);

common_sampler_accept(smpl[i], new_token_id, true);

Expand Down
Loading