Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add single-client multi-prompt support #4232

Merged
merged 11 commits into from
Nov 30, 2023
140 changes: 129 additions & 11 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <thread>
#include <mutex>
#include <chrono>
#include <unordered_set>

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
Expand Down Expand Up @@ -155,15 +156,23 @@ struct task_server {
json data;
bool infill_mode = false;
bool embedding_mode = false;
int multitask_id = -1;
};

struct task_result {
int id;
int multitask_id = -1;
bool stop;
bool error;
json result_json;
};

struct task_multi {
int id;
std::unordered_set<int> subtasks_remaining;
ziedbha marked this conversation as resolved.
Show resolved Hide resolved
std::vector<task_result> results;
};

// TODO: can become bool if we can't find use of more states
enum slot_state
{
Expand Down Expand Up @@ -406,6 +415,9 @@ struct llama_client_slot
double t_prompt_processing; // ms
double t_token_generation; // ms

// multitasks
int multitask_id = -1;

void reset() {
num_prompt_tokens = 0;
generated_text = "";
Expand Down Expand Up @@ -529,7 +541,8 @@ struct llama_server_context

std::vector<task_server> queue_tasks;
std::vector<task_result> queue_results;
std::mutex mutex_tasks;
std::vector<task_multi> queue_multitasks;
std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
std::mutex mutex_results;

~llama_server_context()
Expand Down Expand Up @@ -1112,17 +1125,40 @@ struct llama_server_context
return slot.images.size() > 0;
}

void send_error(int id, std::string error)
void send_error(task_server& task, std::string error)
{
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = id;
res.id = task.id;
res.multitask_id = task.multitask_id;
res.stop = false;
res.error = true;
res.result_json = { { "content", error } };
queue_results.push_back(res);
}

void add_multi_task(int id, std::vector<int>& sub_ids)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
task_multi multi;
multi.id = id;
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
queue_multitasks.push_back(multi);
}

void update_multi_task(int multitask_id, int subtask_id, task_result& result)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
for (auto& multitask : queue_multitasks)
{
if (multitask.id == multitask_id)
{
multitask.subtasks_remaining.erase(subtask_id);
multitask.results.push_back(result);
}
}
}

json get_model_props()
{
return get_formated_generation(slots[0]);
Expand Down Expand Up @@ -1167,6 +1203,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = false;

Expand Down Expand Up @@ -1206,6 +1243,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = true;

Expand Down Expand Up @@ -1251,6 +1289,12 @@ struct llama_server_context
res.result_json["model"] = slot.oaicompat_model;
}

// parent multitask, if any, needs to be updated
if (slot.multitask_id != -1)
{
update_multi_task(slot.multitask_id, slot.task_id, res);
}

queue_results.push_back(res);
}

Expand All @@ -1259,6 +1303,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = true;

Expand All @@ -1285,16 +1330,26 @@ struct llama_server_context
queue_results.push_back(res);
}

int request_completion(json data, bool infill, bool embedding)
int request_completion(json data, bool infill, bool embedding, int multitask_id)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
task.embedding_mode = embedding;
task.type = COMPLETION_TASK;
task.multitask_id = multitask_id;

// when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.at("prompt").size() > 1)
{
lock.unlock(); // entering new func scope
return split_multiprompt_task(task);
}

// otherwise, it's a single-prompt task, we actually queue it
queue_tasks.push_back(task);
return task.id;
}
Expand All @@ -1313,8 +1368,17 @@ struct llama_server_context

for (int i = 0; i < (int) queue_results.size(); i++)
{
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
if (queue_results[i].multitask_id == task_id)
{
update_multi_task(task_id, queue_results[i].id, queue_results[i]);
queue_results.erase(queue_results.begin() + i);
continue;
}

if (queue_results[i].id == task_id)
{
assert(queue_results[i].multitask_id == -1);
task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i);
return res;
Expand Down Expand Up @@ -1404,6 +1468,27 @@ struct llama_server_context
queue_tasks.push_back(task);
}

int split_multiprompt_task(task_server& multiprompt_task)
{
auto prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1);

int multitask_id = id_gen++;
std::vector<int> subtask_ids(prompt_count);
for (int i = 0; i < prompt_count; i++)
{
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];

// subtasks inherit everything else (infill mode, embedding mode, etc.)
subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
}

// queue up the multitask so we can track its subtask progression
add_multi_task(multitask_id, subtask_ids);
return multitask_id;
}

void process_tasks()
{
std::lock_guard<std::mutex> lock(mutex_tasks);
Expand All @@ -1419,7 +1504,7 @@ struct llama_server_context
{
LOG_TEE("slot unavailable\n");
// send error result
send_error(task.id, "slot unavailable");
send_error(task, "slot unavailable");
return;
}

Expand All @@ -1433,11 +1518,12 @@ struct llama_server_context
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;

if (!launch_slot_with_data(slot, task.data))
{
// send error result
send_error(task.id, "internal_error");
send_error(task, "internal_error");
break;
}
} break;
Expand All @@ -1453,6 +1539,38 @@ struct llama_server_context
} break;
}
}

// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
auto queue_iterator = queue_multitasks.begin();
while (queue_iterator != queue_multitasks.end())
{
if (queue_iterator->subtasks_remaining.empty())
{
// all subtasks done == multitask is done
task_result aggregate_result;
aggregate_result.id = queue_iterator->id;
aggregate_result.stop = true;
aggregate_result.error = false;

// collect json results into one json result
std::vector<json> result_jsons;
for (auto& subres : queue_iterator->results)
{
result_jsons.push_back(subres.result_json);
aggregate_result.error = aggregate_result.error && subres.error;
}
aggregate_result.result_json = json{ "results", result_jsons };

std::lock_guard<std::mutex> lock(mutex_results);
queue_results.push_back(aggregate_result);

queue_iterator = queue_multitasks.erase(queue_iterator);
}
else
{
++queue_iterator;
}
}
}

bool update_slots() {
Expand Down Expand Up @@ -2596,7 +2714,7 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
{
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, false, false);
const int task_id = llama.request_completion(data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
Expand Down Expand Up @@ -2685,7 +2803,7 @@ int main(int argc, char **argv)
{
json data = oaicompat_completion_params_parse(json::parse(req.body));

const int task_id = llama.request_completion(data, false, false);
const int task_id = llama.request_completion(data, false, false, -1);

if (!json_value(data, "stream", false)) {
std::string completion_text;
Expand Down Expand Up @@ -2754,7 +2872,7 @@ int main(int argc, char **argv)
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
{
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, true, false);
const int task_id = llama.request_completion(data, true, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
Expand Down Expand Up @@ -2858,7 +2976,7 @@ int main(int argc, char **argv)
{
prompt = "";
}
const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
task_result result = llama.next_result(task_id);
return res.set_content(result.result_json.dump(), "application/json");
});
Expand Down
Loading