server: allow json array in prompt or content for direct token input #2306

Merged
merged 9 commits
Aug 23, 2023
2 changes: 1 addition & 1 deletion examples/server/README.md
@@ -126,7 +126,7 @@ node .

`stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.

`prompt`: Provide a prompt. Internally, the prompt is compared against the previous one; any part that has already been evaluated is reused, and only the remaining part is evaluated. A space is inserted at the front, as main.cpp does.
`prompt`: Provide a prompt as a string, or as an array of strings and numbers representing tokens. Internally, the prompt is compared against the previous one; any part that has already been evaluated is reused, and only the remaining part is evaluated. If the prompt is a string, or an array whose first element is a string, a space is inserted at the front, as main.cpp does.
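For example, with this change a completion request can mix plain text and raw token ids in a single prompt. A hypothetical request body (the numeric value is a placeholder token id, not taken from any real vocabulary):

```json
{
  "prompt": ["Hello, my name is", 12345, " and I like"],
  "n_predict": 64
}
```

Strings in the array are tokenized by the server, numbers are used directly as token ids, and because the first element is a string, a space is still inserted at the front as described above.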

`stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
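As a hypothetical illustration, a stop string can be used to end generation before the model starts a new question on its own (all values here are made up for the example):

```json
{
  "prompt": "Question: What is llama.cpp?\nAnswer:",
  "stop": ["\nQuestion:"]
}
```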
80 changes: 73 additions & 7 deletions examples/server/server.cpp
@@ -190,6 +190,7 @@ struct llama_server_context
size_t n_past = 0;
size_t n_remain = 0;

json prompt;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;

@@ -267,6 +268,53 @@ struct llama_server_context
return true;
}

std::vector<llama_token> tokenize(json json_prompt, bool add_bos)
{
    // If `add_bos` is true, we only add BOS, when json_prompt is a string,
    // or the first element of the json_prompt array is a string.
    std::vector<llama_token> prompt_tokens;

    if (json_prompt.is_array())
    {
        bool first = true;
        for (const auto& p : json_prompt)
        {
            if (p.is_string())
            {
                auto s = p.template get<std::string>();
                std::vector<llama_token> p;
                if (first)
                {
                    s.insert(0, 1, ' '); // add a space if it's the first
                    p = ::llama_tokenize(ctx, s, add_bos);
                    first = false;
                }
                else
                {
                    p = ::llama_tokenize(ctx, s, false);
                }
                prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
            }
            else
            {
                if (first)
                {
                    first = false;
                }
                prompt_tokens.push_back(p.template get<llama_token>());
            }
        }
    }
    else
    {
        auto s = json_prompt.template get<std::string>();
        s.insert(0, 1, ' '); // always add a first space
        prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
    }

    return prompt_tokens;
}
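To illustrate the helper above with a hypothetical input (the token id is made up), consider the JSON prompt:

```json
["Hello", 12345, " world"]
```

`"Hello"` receives a leading space and, when `add_bos` is true, a BOS token; `" world"` is tokenized without BOS; and `12345` is appended verbatim as a token id. If the first element were a number, neither the leading space nor BOS would be added.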

bool loadGrammar()
{
if (!params.grammar.empty()) {
@@ -294,8 +342,8 @@

void loadPrompt()
{
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
auto prompt_tokens = tokenize(prompt, true); // always add BOS

num_prompt_tokens = prompt_tokens.size();

if (params.n_keep < 0)
@@ -1016,7 +1064,7 @@ static json format_final_response(llama_server_context &llama, const std::string
{"tokens_predicted", llama.num_tokens_predicted},
{"tokens_evaluated", llama.num_prompt_tokens},
{"generation_settings", format_generation_settings(llama)},
{"prompt", llama.params.prompt},
{"prompt", llama.prompt},
{"truncated", llama.truncated},
{"stopped_eos", llama.stopped_eos},
{"stopped_word", llama.stopped_word},
@@ -1085,10 +1133,18 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
llama.params.seed = json_value(body, "seed", default_params.seed);
llama.params.prompt = json_value(body, "prompt", default_params.prompt);
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);

if (body.count("prompt") != 0)
{
    llama.prompt = body["prompt"];
}
else
{
    llama.prompt = "";
}

llama.params.logit_bias.clear();
if (json_value(body, "ignore_eos", false))
{
@@ -1345,8 +1401,8 @@ int main(int argc, char **argv)
auto lock = llama.lock();

const json body = json::parse(req.body);
const std::string content = json_value<std::string>(body, "content", "");
const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
std::vector<llama_token> tokens;
if (body.count("content") != 0)
{
    tokens = llama.tokenize(body["content"], false);
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json"); });
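Judging from the handler above (which formats its reply with `format_tokenizer_response`), the tokenization endpoint now accepts `content` either as a plain string or as an array in the same mixed string/token-id format as `prompt`. A hypothetical request body (placeholder token id):

```json
{
  "content": ["Hello", 12345, " world"]
}
```

Here `add_bos` is false, so no BOS token is added; strings are tokenized and numbers are passed through into the returned token list.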

@@ -1358,7 +1417,14 @@ int main(int argc, char **argv)

llama.rewind();
llama_reset_timings(llama.ctx);
llama.params.prompt = json_value<std::string>(body, "content", "");
if (body.count("content") != 0)
{
    llama.prompt = body["content"];
}
else
{
    llama.prompt = "";
}
llama.params.n_predict = 0;
llama.loadPrompt();
llama.beginCompletion();