diff --git a/examples/server/README.md b/examples/server/README.md index 4d97db2e480eb..77997f98d577c 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -126,7 +126,7 @@ node . `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`. - `prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. A space is inserted in the front like main.cpp does. + `prompt`: Provide a prompt as a string, or as an array of strings and numbers representing tokens. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. If the prompt is a string, or an array with the first element given as a string, a space is inserted in the front like main.cpp does. `stop`: Specify a JSON array of stopping strings. These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []). diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e5bc52cd00624..1e6d10c1d79e9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -190,6 +190,7 @@ struct llama_server_context size_t n_past = 0; size_t n_remain = 0; + json prompt; std::vector embd; std::vector last_n_tokens; @@ -267,6 +268,53 @@ struct llama_server_context return true; } + std::vector tokenize(json json_prompt, bool add_bos) + { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + std::vector prompt_tokens; + + if (json_prompt.is_array()) + { + bool first = true; + for (const auto& p : json_prompt) + { + if (p.is_string()) + { + auto s = p.template get(); + std::vector p; + if (first) + { + s.insert(0, 1, ' '); // add a space if it's the first + p = ::llama_tokenize(ctx, s, add_bos); + first = false; + } + else + { + p = ::llama_tokenize(ctx, s, false); + } + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } + else + { + if (first) + { + first = false; + } + prompt_tokens.push_back(p.template get()); + } + } + } + else + { + auto s = json_prompt.template get(); + s.insert(0, 1, ' '); // always add a first space + prompt_tokens = ::llama_tokenize(ctx, s, add_bos); + } + + return prompt_tokens; + } + bool loadGrammar() { if (!params.grammar.empty()) { @@ -294,8 +342,8 @@ struct llama_server_context void loadPrompt() { - params.prompt.insert(0, 1, ' '); // always add a first space - std::vector prompt_tokens = ::llama_tokenize(ctx, params.prompt, true); + auto prompt_tokens = tokenize(prompt, true); // always add BOS + num_prompt_tokens = prompt_tokens.size(); if (params.n_keep < 0) @@ -1016,7 +1064,7 @@ static json format_final_response(llama_server_context &llama, const std::string {"tokens_predicted", llama.num_tokens_predicted}, {"tokens_evaluated", llama.num_prompt_tokens}, {"generation_settings", format_generation_settings(llama)}, - {"prompt", llama.params.prompt}, + {"prompt", llama.prompt}, {"truncated", llama.truncated}, {"stopped_eos", llama.stopped_eos}, {"stopped_word", llama.stopped_word}, @@ -1085,10 +1133,18 @@ static void parse_options_completion(const json &body, llama_server_context &lla llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl); llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); llama.params.seed = json_value(body, "seed", default_params.seed); - llama.params.prompt = json_value(body, "prompt", default_params.prompt); llama.params.grammar = json_value(body, "grammar", default_params.grammar); llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs); + if (body.count("prompt") != 0) + { + llama.prompt = body["prompt"]; + } + else + { + llama.prompt = ""; + } + llama.params.logit_bias.clear(); if (json_value(body, "ignore_eos", false)) { @@ -1345,8 +1401,11 @@ int main(int argc, char **argv) auto lock = llama.lock(); const json body = json::parse(req.body); - const std::string content = json_value(body, "content", ""); - const std::vector tokens = llama_tokenize(llama.ctx, content, false); + std::vector tokens; + if (body.count("content") != 0) + { + tokens = llama.tokenize(body["content"], false); + } const json data = format_tokenizer_response(tokens); return res.set_content(data.dump(), "application/json"); }); @@ -1358,7 +1417,14 @@ int main(int argc, char **argv) llama.rewind(); llama_reset_timings(llama.ctx); - llama.params.prompt = json_value(body, "content", ""); + if (body.count("content") != 0) + { + llama.prompt = body["content"]; + } + else + { + llama.prompt = ""; + } llama.params.n_predict = 0; llama.loadPrompt(); llama.beginCompletion();