Skip to content

Commit 0c06204

Browse files
authored
main : add --in-prefix-bos to prefix BOS to user inputs; keep EOS (ggml-org#2304)
* add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS The BOS precedes the string specified by `--in-prefix`. Model-generated EOS is now kept in the context. It provides a way to strictly follow the prompt format used in Llama-2-chat. The EOS handling also benefits some existing finetunes that use EOS to mark the end of a turn. * examples/common: move input_prefix_bos to other bools
1 parent 1fed755 commit 0c06204

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

examples/common.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
432432
exit(0);
433433
} else if (arg == "--random-prompt") {
434434
params.random_prompt = true;
435+
} else if (arg == "--in-prefix-bos") {
436+
params.input_prefix_bos = true;
435437
} else if (arg == "--in-prefix") {
436438
if (++i >= argc) {
437439
invalid_param = true;
@@ -517,6 +519,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
517519
fprintf(stdout, " not supported with --interactive or other interactive options\n");
518520
fprintf(stdout, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n");
519521
fprintf(stdout, " --random-prompt start with a randomized prompt.\n");
522+
fprintf(stdout, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n");
520523
fprintf(stdout, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
521524
fprintf(stdout, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
522525
fprintf(stdout, " -f FNAME, --file FNAME\n");

examples/common.h

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ struct gpt_params {
8282
bool interactive_first = false; // wait for user input immediately
8383
bool multiline_input = false; // reverse the usage of `\`
8484

85+
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
8586
bool instruct = false; // instruction mode (used for Alpaca models)
8687
bool penalize_nl = true; // consider newlines as a repeatable token
8788
bool perplexity = false; // compute perplexity over the prompt

examples/main/main.cpp

+30-17
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,10 @@ int main(int argc, char ** argv) {
325325
}
326326
}
327327

328+
if (params.input_prefix_bos) {
329+
fprintf(stderr, "Input prefix with BOS\n");
330+
}
331+
328332
if (!params.input_prefix.empty()) {
329333
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
330334
}
@@ -633,16 +637,6 @@ int main(int argc, char ** argv) {
633637
last_n_tokens.push_back(id);
634638
}
635639

636-
// replace end of text token with newline token when in interactive mode
637-
if (id == llama_token_eos() && params.interactive && !params.instruct) {
638-
id = llama_token_newline.front();
639-
if (params.antiprompt.size() != 0) {
640-
// tokenize and inject first reverse prompt
641-
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
642-
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
643-
}
644-
}
645-
646640
// add it to the context
647641
embd.push_back(id);
648642

@@ -708,11 +702,34 @@ int main(int argc, char ** argv) {
708702
}
709703
}
710704

705+
// deal with end of text token in interactive mode
706+
if (last_n_tokens.back() == llama_token_eos()) {
707+
if (params.interactive) {
708+
if (params.antiprompt.size() != 0) {
709+
// tokenize and inject first reverse prompt
710+
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
711+
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
712+
is_antiprompt = true;
713+
}
714+
715+
is_interacting = true;
716+
printf("\n");
717+
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
718+
fflush(stdout);
719+
} else if (params.instruct) {
720+
is_interacting = true;
721+
}
722+
}
723+
711724
if (n_past > 0 && is_interacting) {
712725
if (params.instruct) {
713726
printf("\n> ");
714727
}
715728

729+
if (params.input_prefix_bos) {
730+
embd_inp.push_back(llama_token_bos());
731+
}
732+
716733
std::string buffer;
717734
if (!params.input_prefix.empty()) {
718735
buffer += params.input_prefix;
@@ -776,13 +793,9 @@ int main(int argc, char ** argv) {
776793
}
777794

778795
// end of text token
779-
if (!embd.empty() && embd.back() == llama_token_eos()) {
780-
if (params.instruct) {
781-
is_interacting = true;
782-
} else {
783-
fprintf(stderr, " [end of text]\n");
784-
break;
785-
}
796+
if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
797+
fprintf(stderr, " [end of text]\n");
798+
break;
786799
}
787800

788801
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.

0 commit comments

Comments
 (0)