From db97c8b19b0bed25429123fd4000fb137b07fd58 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 6 Dec 2024 15:01:12 +0100 Subject: [PATCH 1/7] server : (refactor) no more json in server_task input --- examples/server/server.cpp | 716 +++++++++--------- .../server/tests/unit/test_chat_completion.py | 5 + examples/server/utils.hpp | 16 +- 3 files changed, 372 insertions(+), 365 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 809fafa187add..d3e6c5e83e494 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -82,28 +82,6 @@ enum error_type { ERROR_TYPE_NOT_SUPPORTED, // custom error }; -struct server_task { - int id = -1; // to be filled by server_queue - int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL - - llama_tokens prompt_tokens; - server_task_type type; - - // TODO @ngxson : we should get rid of json type here - json data; - - server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; - - // utility function - static std::unordered_set get_list_id(const std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } -}; - struct slot_params { bool stream = true; bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt @@ -118,6 +96,7 @@ struct slot_params { std::vector antiprompt; bool timings_per_token = false; + bool ignore_eos = false; struct common_params_sampling sampling; struct common_params_speculative speculative; @@ -134,7 +113,7 @@ struct slot_params { std::string oaicompat_model; std::string oaicompat_cmpl_id; - json to_json() { + json to_json() const { std::vector samplers; samplers.reserve(sampling.samplers.size()); for (const auto & sampler : sampling.samplers) { @@ -172,7 +151,7 @@ struct slot_params { {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, {"stream", stream}, - //{"logit_bias", sampling.logit_bias}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, @@ -186,6 +165,212 @@ struct slot_params { } }; +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + server_task_type type; + server_task_inf_type inf_type; + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + llama_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + server_task( + server_task_type type, + server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {} + + static slot_params params_from_json_cmpl( + const llama_model * model, + const common_params & params_base, + const json & data) { + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + 
params.timings_per_token = json_value(data, "timings_per_token", false);
+
+        params.stream = json_value(data, "stream", false);
+        params.cache_prompt = json_value(data, "cache_prompt", true);
+        params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+        params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+        params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+        params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+        //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+        params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+
+        params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+        params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+        params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+        params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+        params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+        params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+        params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+        params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+        params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+        params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+        params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+        params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+        params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+        params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+        params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+        params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+        params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+        params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+        params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+        params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+        params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
+        params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+        params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+        params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+
+        params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+        params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+        params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+        params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
+        params.speculative.n_min = std::max(params.speculative.n_min, 2);
+        params.speculative.n_max = std::max(params.speculative.n_max, 0);
+
+        if
(params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + params.sampling.grammar = json_schema_to_grammar(schema); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_n_vocab(model); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(model, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto & samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + std::vector sampler_names; + for (const auto & name : *samplers) { + if (name.is_string()) { + sampler_names.emplace_back(name); + } + } + params.sampling.samplers = common_sampler_types_from_names(sampler_names, false); + } else if (samplers->is_string()){ + std::string sampler_string; + for (const auto & name : *samplers) { + sampler_string += name; + } + params.sampling.samplers = common_sampler_types_from_chars(sampler_string); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? 
DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + struct result_timings { int32_t prompt_n = -1; double prompt_ms; @@ -197,7 +382,7 @@ struct result_timings { double predicted_per_token_ms; double predicted_per_second; - json to_json() { + json to_json() const { return { {"prompt_n", prompt_n}, {"prompt_ms", prompt_ms}, @@ -861,6 +1046,22 @@ struct server_slot { return timings; } + json to_json() const { + json res = params.to_json(); + res["id"] = id; + res["id_task"] = id_task; + res["is_processing"] = is_processing(); + res["prompt_tokens"] = prompt_tokens; + res["next_token"] = { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + }; + return res; + } + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { size_t stop_pos = std::string::npos; @@ -978,9 +1179,7 @@ struct server_queue { // Add a new task to the end of the queue int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); - if (task.id == -1) { - task.id = id++; - } + GGML_ASSERT(task.id != -1); QUE_DBG("new task, id = %d, front = %d\n", task.id, front); if (front) { queue_tasks.push_front(std::move(task)); @@ -1458,104 +1657,14 @@ struct server_context { } bool launch_slot_with_task(server_slot & slot, const server_task & task) { - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - slot_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - - const auto & data = task.data; + slot.reset(); + slot.id_task = task.id; + slot.inf_type = task.inf_type; + slot.index = task.index; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); - if (data.count("__oaicompat") != 0) { - std::string model_name = params_base.model_alias.empty() ? 
DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - slot.params.oaicompat = true; - slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false); - slot.params.oaicompat_model = json_value(data, "model", model_name); - slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string()); - } else { - slot.params.oaicompat = false; - } - - - // enabling this will output extra debug information in the HTTP responses from the server - slot.params.verbose = params_base.verbosity > 9; - slot.params.timings_per_token = json_value(data, "timings_per_token", false); - - slot.params.stream = json_value(data, "stream", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", true); - slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent); - slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep); - slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard); - //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - - slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl); - slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - slot.params.sampling.n_probs = json_value(data, 
"n_probs", defaults.sampling.n_probs); - slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - - slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); - slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); - slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); - - slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); - slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2); - slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); - - if (slot.params.sampling.dry_base < 1.0f) { - slot.params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (slot.params.sampling.dry_sequence_breakers.empty()) { - send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST); - return false; - } - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - slot.params.sampling.grammar = json_schema_to_grammar(schema); - } catch (const std::exception & e) { - send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST); - return false; - } - } else { - slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - } + SLT_DBG(slot, "launching slot : %s\n", slot.to_json().dump().c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? 
@@ -1563,78 +1672,8 @@ struct server_context { SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } - { - slot.params.sampling.logit_bias.clear(); - - if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY}); - } - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(model); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - slot.params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(model, el[0].get(), false); - for (auto tok : toks) { - slot.params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } - } - - { - slot.params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - slot.params.antiprompt.push_back(word); - } - } - } - } - - { - const auto & samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - std::vector sampler_names; - for (const auto & name : *samplers) { - if (name.is_string()) { - sampler_names.emplace_back(name); - } - } - slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false); - } else if (samplers->is_string()){ - std::string sampler_string; - for (const auto & name : *samplers) { - sampler_string += name; - } - slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string); - } - } else { - slot.params.sampling.samplers = defaults.sampling.samplers; - } + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY}); } { @@ -2007,81 +2046,13 @@ struct server_context { // Functions to create new task(s) and receive result(s) // - // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s) - std::vector create_tasks_inference(json data, server_task_inf_type inf_type) { - std::vector tasks; - auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) { - SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size()); - server_task task; - task.id = queue_tasks.get_new_id(); - task.inf_type = inf_type; - task.type = SERVER_TASK_TYPE_INFERENCE; - task.data = task_data; - task.prompt_tokens = std::move(prompt_tokens); - tasks.push_back(std::move(task)); - }; - - static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts"; - if (!data.contains("prompt")) { - throw std::runtime_error(error_msg); - } - - // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread - bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL; - std::vector tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true); - switch (inf_type) { - case SERVER_TASK_INF_TYPE_RERANK: - { - // prompts[0] is the question - // the rest are the 
answers/documents - GGML_ASSERT(tokenized_prompts.size() > 1); - SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1); - for (size_t i = 1; i < tokenized_prompts.size(); i++) { - data["index"] = i - 1; - auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]); - create_task(data, tokens); - } - } break; - case SERVER_TASK_INF_TYPE_INFILL: - { - SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - data["index"] = i; - auto tokens = format_infill( - ctx, - data.at("input_prefix"), - data.at("input_suffix"), - data.at("input_extra"), - params_base.n_batch, - params_base.n_predict, - slots[0].n_ctx, // TODO: there should be a better way - params_base.spm_infill, - tokenized_prompts[i] - ); - create_task(data, tokens); - } - } break; - default: - { - SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - data["index"] = i; - create_task(data, tokenized_prompts[i]); - } - } - } - - return tasks; - } - void cancel_tasks(const std::unordered_set & id_tasks) { std::vector cancel_tasks; cancel_tasks.reserve(id_tasks.size()); for (const auto & id_task : id_tasks) { SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task; - task.type = SERVER_TASK_TYPE_CANCEL; + server_task task(SERVER_TASK_TYPE_CANCEL); task.id_target = id_task; cancel_tasks.push_back(task); queue_results.remove_waiting_task_id(id_task); @@ -2090,7 +2061,7 @@ struct server_context { queue_tasks.post(cancel_tasks, true); } - // receive the results from task(s) created by create_tasks_inference + // receive the results from task(s) void receive_multi_results( const std::unordered_set & id_tasks, const std::function&)> & result_handler, @@ -2117,7 +2088,7 @@ struct server_context { result_handler(results); } - // receive the results from task(s) created by create_tasks_inference, in stream mode + // receive the results from task(s), in stream mode void receive_cmpl_results_stream( const std::unordered_set & id_tasks, const std::function & result_handler, @@ -2154,7 +2125,7 @@ struct server_context { switch (task.type) { case SERVER_TASK_TYPE_INFERENCE: { - const int id_slot = json_value(task.data, "id_slot", -1); + const int id_slot = task.id_selected_slot; server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); @@ -2171,13 +2142,6 @@ struct server_context { break; } - slot->reset(); - - slot->id_task = task.id; - slot->inf_type = task.inf_type; - slot->index = json_value(task.data, "index", 0); - slot->prompt_tokens = std::move(task.prompt_tokens); - if (!launch_slot_with_task(*slot, task)) { SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); break; @@ -2205,18 +2169,8 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = slot.params.to_json(); - slot_data["id"] = slot.id; - slot_data["id_task"] = slot.id_task; - slot_data["is_processing"] = slot.is_processing(); - slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens); - slot_data["next_token"] = { - {"has_next_token", slot.has_next_token}, - {"has_new_line", slot.has_new_line}, - {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, - {"stopping_word", slot.stopping_word}, - }; + json slot_data = slot.to_json(); + slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens); if (slot.is_processing()) { n_processing_slots++; @@ -2251,14 +2205,14 @@ struct server_context { res->n_decode_total = metrics.n_decode_total; res->n_busy_slots_total = metrics.n_busy_slots_total; - if (json_value(task.data, "reset_bucket", false)) { + if (task.metrics_reset_bucket) { metrics.reset_bucket(); } queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -2274,8 +2228,8 @@ struct server_context { const size_t token_count = slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); @@ -2294,7 +2248,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -2309,8 +2263,8 @@ struct server_context { const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; slot->cache_tokens.resize(slot->n_ctx); size_t token_count = 0; @@ -2337,7 +2291,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -2396,10 +2350,8 @@ struct server_context { { SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - server_task task; - task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; - task.id_target = -1; - + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); queue_tasks.post(task); } @@ -3136,10 +3088,8 @@ int main(int argc, char ** argv) { } // request slots data using task queue - server_task task; + 
server_task task(SERVER_TASK_TYPE_METRICS); task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task, true); // high-priority task @@ -3174,11 +3124,9 @@ int main(int argc, char ** argv) { } // request slots data using task queue - server_task task; + server_task task(SERVER_TASK_TYPE_METRICS); task.id = ctx_server.queue_tasks.get_new_id(); - task.id_target = -1; - task.type = SERVER_TASK_TYPE_METRICS; - task.data.push_back({{"reset_bucket", true}}); + task.metrics_reset_bucket = true; ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task, true); // high-priority task @@ -3282,19 +3230,17 @@ int main(int argc, char ** argv) { } std::string filepath = params.slot_save_path + filename; - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_SAVE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); @@ -3313,19 +3259,17 @@ int main(int argc, char ** argv) { } std::string filepath = params.slot_save_path + filename; - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_RESTORE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); @@ -3337,17 +3281,15 @@ int main(int argc, char ** argv) { }; const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_ERASE; - task.data = { - { "id_slot", id_slot }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - 
ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); @@ -3416,14 +3358,41 @@ int main(int argc, char ** argv) { server_task_inf_type inf_type, json & data, httplib::Response & res, - bool oai_compat = false) { + bool oaicompat = false, + bool oaicompat_chat = false) { if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - data["completion_id"] = gen_chatcmplid(); - std::vector tasks = ctx_server.create_tasks_inference(data, inf_type); + auto completion_id = gen_chatcmplid(); + std::vector tasks; + + try { + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true); + tasks.reserve(tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(tokenized_prompts[i]); + task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_chat = oaicompat_chat; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(task); + } + } catch (const std::exception & e) { + res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3449,7 +3418,7 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { - const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) { + const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) { ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { json res_json = result->to_json(); if (res_json.is_array()) { @@ -3465,7 +3434,7 @@ int main(int argc, char ** argv) { }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); - if (oai_compat) { + if (oaicompat) { static const std::string ev_done = "data: [DONE]\n\n"; sink.write(ev_done.data(), ev_done.size()); } @@ -3483,7 +3452,12 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res); + return handle_completions_generic( + SERVER_TASK_INF_TYPE_COMPLETION, + data, + res, + /* oaicompat */ false, + /* oaicompat_chat */ false); }; const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -3543,8 +3517,12 @@ int main(int argc, char ** argv) { } json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - data["__oaicompat_chat"] = true; - return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true); + return 
handle_completions_generic( + SERVER_TASK_INF_TYPE_COMPLETION, + data, + res, + /* oaicompat */ true, + /* oaicompat_chat */ true); }; const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { @@ -3638,7 +3616,16 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING); + std::vector tasks; + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = std::move(tokenized_prompts[i]); + tasks.push_back(task); + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3665,7 +3652,7 @@ int main(int argc, char ** argv) { // write JSON response json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) - : responses[0]; + : responses.size() == 1 ? responses[0] : json(responses); res_ok(res, root); }; @@ -3704,20 +3691,23 @@ int main(int argc, char ** argv) { return; } - // construct prompt object: array of ["query", "doc0", "doc1", ...] - json prompt; - prompt.push_back(query); - for (const auto & doc : documents) { - prompt.push_back(doc); - } - - LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str()); + llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0]; // create and queue the task json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK); + std::vector tasks; + std::vector tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true); + tasks.reserve(tokenized_docs.size()); + for (size_t i = 0; i < tokenized_docs.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3778,13 +3768,13 @@ int main(int argc, char ** argv) { } } - server_task task; - task.type = SERVER_TASK_TYPE_SET_LORA; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = ctx_server.queue_tasks.get_new_id(); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index f13c6c4ca4bd3..d4ee360f82642 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, 
user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert "cmpl" in res.body["id"]
     assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         "stream": True,
     })
     content = ""
+    last_cmpl_id = None
     for data in res:
         choice = data["choices"][0]
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        if last_cmpl_id is None:
+            last_cmpl_id = data["id"]
+        assert last_cmpl_id == data["id"]
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index a96116ac36caa..24673abc84011 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     } else {
         throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
     }
+    if (result.empty()) {
+        throw std::runtime_error("\"prompt\" must not be empty");
+    }
     return result;
 }
@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
     const std::string & chat_template) {
     json llama_params;
 
-    llama_params["__oaicompat"] = true;
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
@@ -648,3 +649,14 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
     };
 }
+
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
+    }
+    return data;
+}

From 9bb1ae6beacf652cd6a10189571bfc21b8380310 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 7 Dec 2024 13:56:14 +0100
Subject: [PATCH 2/7] add test for slots endpoint

---
 examples/server/server.cpp                         |  1 +
 examples/server/tests/unit/test_basic.py           | 11 +++++++++++
 examples/server/tests/unit/test_chat_completion.py |  4 ++--
 examples/server/tests/utils.py                     |  3 +++
 4 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d3e6c5e83e494..745df6b5c8bc2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2184,6 +2184,7 @@ struct server_context {
                     auto res = std::make_unique<server_task_result_metrics>();
                     res->id = task.id;
+                    res->slots_data = slots_data;
                     res->n_idle_slots = n_idle_slots;
                     res->n_processing_slots = n_processing_slots;
                     res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py
index d82d54a5a6f47..b8d68989ba0b8 100644
--- a/examples/server/tests/unit/test_basic.py
+++ b/examples/server/tests/unit/test_basic.py
@@ -33,6 +33,17 @@ def test_server_models():
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias
 
+
+def test_server_slots():
+    global server
+    server.server_slots = True
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 200
+    assert len(res.body) == server.n_slots
+    assert res.body[0]["n_ctx"] > 0
+
+
 def test_load_split_model():
     global server
     server.model_hf_repo = "ggml-org/models"
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index d4ee360f82642..6573cc17f7b87 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -30,7 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte ], }) assert res.status_code == 200 - assert "cmpl" in res.body["id"] + assert "cmpl" in res.body["id"] # make sure the completion id has the expected format assert res.body["model"] == model if model is not None else server.model_alias assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["completion_tokens"] == n_predicted @@ -66,7 +66,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future if last_cmpl_id is None: last_cmpl_id = data["id"] - assert last_cmpl_id == data["id"] + assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream if choice["finish_reason"] in ["stop", "length"]: assert data["usage"]["prompt_tokens"] == n_prompt assert data["usage"]["completion_tokens"] == n_predicted diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index e17a05ff6902a..c352943f25133 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -64,6 +64,7 @@ class ServerProcess: server_embeddings: bool | None = False server_reranking: bool | None = False server_metrics: bool | None = False + server_slots: bool | None = False draft: int | None = None api_key: str | None = None response_format: str | None = None @@ -129,6 +130,8 @@ def start(self, timeout_seconds: int = 10) -> None: server_args.append("--reranking") if self.server_metrics: server_args.append("--metrics") + if self.server_slots: + server_args.append("--slots") if self.model_alias: server_args.extend(["--alias", self.model_alias]) if self.n_ctx: From e721f4c6b443395e168e05ca96b46fbd424d7635 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 7 Dec 2024 19:17:35 +0100 Subject: [PATCH 3/7] add tests for /props and /slots --- examples/server/tests/unit/test_basic.py | 20 +++++++++++++++++++- examples/server/tests/utils.py | 7 +++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py index b8d68989ba0b8..22c6fe545903d 100644 --- a/examples/server/tests/unit/test_basic.py +++ b/examples/server/tests/unit/test_basic.py @@ -23,6 +23,10 @@ def test_server_props(): res = server.make_request("GET", "/props") assert res.status_code == 200 assert res.body["total_slots"] == server.n_slots + default_val = res.body["default_generation_settings"] + assert server.n_ctx is not None and server.n_slots is not None + assert default_val["n_ctx"] == server.n_ctx / server.n_slots + assert default_val["params"]["seed"] == server.seed def test_server_models(): @@ -36,12 +40,26 @@ def test_server_models(): def test_server_slots(): global server + + # without slots endpoint enabled, this should return error + server.server_slots = False + server.start() + res = server.make_request("GET", "/slots") + assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED + assert "error" in res.body + server.stop() + + # with slots endpoint enabled, this should return slots info server.server_slots = True + server.n_slots = 2 server.start() res = server.make_request("GET", 
"/slots") assert res.status_code == 200 assert len(res.body) == server.n_slots - assert res.body[0]["n_ctx"] > 0 + assert server.n_ctx is not None and server.n_slots is not None + assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots + assert "params" in res.body[0] + assert res.body[0]["params"]["seed"] == server.seed def test_load_split_model(): diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index c352943f25133..69215eaa4ebb7 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -92,7 +92,6 @@ def start(self, timeout_seconds: int = 10) -> None: else: server_path = "../../../build/bin/llama-server" server_args = [ - "--slots", # requires to get slot status via /slots endpoint "--host", self.server_host, "--port", @@ -184,7 +183,7 @@ def start(self, timeout_seconds: int = 10) -> None: start_time = time.time() while time.time() - start_time < timeout_seconds: try: - response = self.make_request("GET", "/slots", headers={ + response = self.make_request("GET", "/health", headers={ "Authorization": f"Bearer {self.api_key}" if self.api_key else None }) if response.status_code == 200: @@ -227,7 +226,7 @@ def make_request( result.headers = dict(response.headers) result.status_code = response.status_code result.body = response.json() if parse_body else None - print("Response from server", result.body) + print("Response from server", json.dumps(result.body, indent=2)) return result def make_stream_request( @@ -248,7 +247,7 @@ def make_stream_request( break elif line.startswith('data: '): data = json.loads(line[6:]) - print("Partial response from server", data) + print("Partial response from server", json.dumps(data, indent=2)) yield data From 090a113417117b90c7c8a946316caa6d9cba7544 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 7 Dec 2024 19:33:40 +0100 Subject: [PATCH 4/7] remove task inf_type --- examples/server/server.cpp | 68 ++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 01e9a6da1cd5e..6e97ff7d07141 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -54,7 +54,10 @@ enum server_state { }; enum server_task_type { - SERVER_TASK_TYPE_INFERENCE, + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, SERVER_TASK_TYPE_METRICS, @@ -64,13 +67,6 @@ enum server_task_type { SERVER_TASK_TYPE_SET_LORA, }; -enum server_task_inf_type { - SERVER_TASK_INF_TYPE_COMPLETION, - SERVER_TASK_INF_TYPE_EMBEDDING, - SERVER_TASK_INF_TYPE_RERANK, - SERVER_TASK_INF_TYPE_INFILL, -}; - // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 enum error_type { ERROR_TYPE_INVALID_REQUEST, @@ -163,8 +159,7 @@ struct server_task { int id = -1; // to be filled by server_queue int index = -1; // used when there are multiple prompts (batch request) - server_task_type type; - server_task_inf_type inf_type; + server_task_type type; // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; @@ -185,9 +180,7 @@ struct server_task { // used by SERVER_TASK_TYPE_METRICS bool metrics_reset_bucket = false; - server_task( - server_task_type type, - server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {} + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( const llama_model * model, @@ -893,6 +886,9 
@@ struct server_slot { int id; int id_task = -1; + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + llama_batch batch_spec = {}; llama_context * ctx = nullptr; @@ -931,8 +927,6 @@ struct server_slot { llama_tokens cache_tokens; std::vector generated_token_probs; - server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; - bool has_next_token = true; bool has_new_line = false; bool truncated = false; @@ -972,11 +966,15 @@ struct server_slot { n_past = 0; n_sent_text = 0; n_sent_token_probs = 0; - inf_type = SERVER_TASK_INF_TYPE_COMPLETION; + task_type = SERVER_TASK_TYPE_COMPLETION; generated_token_probs.clear(); } + bool is_non_causal() const { + return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + } + bool has_budget(const common_params & global_params) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless @@ -1088,6 +1086,7 @@ struct server_slot { {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, + {"non_causal", is_non_causal()}, {"params", params.to_json()}, {"prompt", common_detokenize(ctx, prompt_tokens)}, {"next_token", @@ -1653,8 +1652,8 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, const server_task & task) { slot.reset(); slot.id_task = task.id; - slot.inf_type = task.inf_type; slot.index = task.index; + slot.task_type = task.type; slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); @@ -2120,7 +2119,10 @@ struct server_context { void process_single_task(server_task task) { switch (task.type) { - case SERVER_TASK_TYPE_INFERENCE: + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: { const int id_slot = task.id_selected_slot; @@ -2462,7 +2464,7 @@ struct server_context { continue; } - if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { + if (slot.is_non_causal()) { if (slot.n_prompt_tokens > n_ubatch) { slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); @@ -2577,7 +2579,7 @@ struct server_context { } // non-causal tasks require to fit the entire prompt in the physical batch - if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { + if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2585,10 +2587,7 @@ struct server_context { } // check that we are in the right batch_type, if not defer the slot - const bool slot_type = - slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || - slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 
1 : 0; - + int slot_type = slot.is_non_causal(); if (batch_type == -1) { batch_type = slot_type; } else if (batch_type != slot_type) { @@ -2705,7 +2704,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -2713,7 +2712,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -3352,11 +3351,13 @@ int main(int argc, char ** argv) { // handle completion-like requests (completion, chat, infill) // we can optionally provide a custom format for partial results and final results const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok]( - server_task_inf_type inf_type, + server_task_type type, json & data, httplib::Response & res, bool oaicompat = false, bool oaicompat_chat = false) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; @@ -3369,7 +3370,8 @@ int main(int argc, char ** argv) { std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true); tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type); + server_task task = server_task(type); + task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; @@ -3450,7 +3452,7 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); return handle_completions_generic( - SERVER_TASK_INF_TYPE_COMPLETION, + SERVER_TASK_TYPE_COMPLETION, data, res, /* oaicompat */ false, @@ -3504,7 +3506,7 @@ int main(int argc, char ** argv) { } data["input_extra"] = input_extra; // default to empty array if it's not exist - return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res); + return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res); }; const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -3515,7 +3517,7 @@ int main(int argc, char ** argv) { json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); return handle_completions_generic( - SERVER_TASK_INF_TYPE_COMPLETION, + SERVER_TASK_TYPE_COMPLETION, data, res, /* oaicompat */ true, @@ -3616,7 +3618,7 @@ int main(int argc, char ** argv) { std::vector tasks; std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true); for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING); + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = std::move(tokenized_prompts[i]); @@ -3698,7 +3700,7 @@ int main(int argc, char ** argv) { std::vector tokenized_docs = tokenize_input_prompts(ctx_server.ctx, 
documents, /* add_special */ false, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]); From 65d2e6d67534db4fabc0b3ebc576091eb911a620 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 7 Dec 2024 19:45:51 +0100 Subject: [PATCH 5/7] fix CI by adding safe_json_to_str --- examples/server/server.cpp | 10 +++++----- examples/server/utils.hpp | 4 ++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6e97ff7d07141..5edf774cbf07e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1657,7 +1657,7 @@ struct server_context { slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); - SLT_DBG(slot, "launching slot : %s\n", slot.to_json().dump().c_str()); + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? @@ -2942,12 +2942,12 @@ int main(int argc, char ** argv) { auto res_error = [](httplib::Response & res, const json & error_data) { json final_response {{"error", error_data}}; - res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); res.status = json_value(error_data, "code", 500); }; auto res_ok = [](httplib::Response & res, const json & data) { - res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.set_content(safe_json_to_str(data), MIMETYPE_JSON); res.status = 200; }; @@ -3524,7 +3524,7 @@ int main(int argc, char ** argv) { /* oaicompat_chat */ true); }; - const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { + const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { json models = { {"object", "list"}, {"data", { @@ -3538,7 +3538,7 @@ int main(int argc, char ** argv) { }} }; - res.set_content(models.dump(), MIMETYPE_JSON); + res_ok(res, models); }; const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 79c0a00fb36cb..8f545aea52dc4 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -660,3 +660,7 @@ static json format_logit_bias(const std::vector & logit_bias) } return data; } + +static std::string safe_json_to_str(json data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} From 1949f68f4e3ca4e7ed022d638800b4bddb35469d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 7 Dec 2024 20:05:04 +0100 Subject: [PATCH 6/7] add "model_path" to /props --- examples/server/server.cpp | 5 ++++- examples/server/tests/unit/test_basic.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5edf774cbf07e..1c21e55aaa011 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -801,7 +801,8 @@ struct server_task_result_metrics : server_task_result { uint64_t 
n_decode_total = 0;
     uint64_t n_busy_slots_total = 0;
 
-    // TODO: get rid of this json object and use to_json() instead
+    // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
+    // therefore, we use json to temporarily store the slot.to_json() result
     json slots_data = json::array();
 
     virtual json to_json() override {
@@ -3326,9 +3327,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        // this endpoint is publicly available, please only return what is safe to be exposed
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params_base.n_parallel },
+            { "model_path", ctx_server.params_base.model },
             { "chat_template", llama_get_chat_template(ctx_server.model) },
         };
 
diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py
index 22c6fe545903d..1d512401685fb 100644
--- a/examples/server/tests/unit/test_basic.py
+++ b/examples/server/tests/unit/test_basic.py
@@ -22,6 +22,7 @@ def test_server_props():
     server.start()
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
+    assert ".gguf" in res.body["model_path"]
     assert res.body["total_slots"] == server.n_slots
     default_val = res.body["default_generation_settings"]
     assert server.n_ctx is not None and server.n_slots is not None

From 89c2af9099793294565976fab7eba40a69a38a07 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 7 Dec 2024 20:07:51 +0100
Subject: [PATCH 7/7] update readme

---
 examples/server/README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/server/README.md b/examples/server/README.md
index 0bab40a82250c..117c52d3f1310 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
       }
     },
     "total_slots": 1,
+    "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
     "chat_template": "..."
   }
 ```
 
 - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
 - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
+- `model_path` - the path to model file (same with `-m` argument)
 - `chat_template` - the model's original Jinja2 prompt template
 
 ### POST `/props`: Change server global properties.
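
A quick way to see the endpoints touched by this series in action is sketched below. This is a minimal illustration, not part of the patches: it assumes a server built from this branch is already running at `http://localhost:8080` and was started with `--slots`, and it uses the generic `requests` package instead of the test harness in `tests/utils.py`. The fields it reads are only the ones added or asserted in the patches above (`model_path` and `total_slots` from `/props`, and the per-slot `n_ctx`/`params` entries from `/slots`).

```python
# Minimal sketch -- assumes a server from this branch at http://localhost:8080, started with --slots.
import requests

base_url = "http://localhost:8080"  # assumed host/port, adjust as needed

# /props is public; after PATCH 6/7 it also exposes the model path.
props = requests.get(f"{base_url}/props").json()
print(props["model_path"])                   # path given via -m, e.g. ends with ".gguf"
print(props["total_slots"])                  # number of slots (--parallel)
print(props["default_generation_settings"])  # per-slot defaults, includes nested "params"

# /slots returns one entry per slot, now built from slot.to_json().
for slot in requests.get(f"{base_url}/slots").json():
    print(slot["id"], slot["n_ctx"], slot["is_processing"], slot["params"]["seed"])
```

Without `--slots`, the `/slots` request answers 501 (`ERROR_TYPE_NOT_SUPPORTED`), matching the new test in `test_basic.py`.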