Skip to content

Commit 332bdfd

Browse files
server : maintain chat completion id for streaming responses (#5988)
* server: maintain chat completion id for streaming responses
* Update examples/server/utils.hpp
* Update examples/server/utils.hpp
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent ecab1c7 commit 332bdfd

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

Diff for: examples/server/server.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -3195,11 +3195,12 @@ int main(int argc, char ** argv) {
31953195
ctx_server.queue_results.add_waiting_task_id(id_task);
31963196
ctx_server.request_completion(id_task, -1, data, false, false);
31973197

3198+
const auto completion_id = gen_chatcmplid();
31983199
if (!json_value(data, "stream", false)) {
31993200
server_task_result result = ctx_server.queue_results.recv(id_task);
32003201

32013202
if (!result.error && result.stop) {
3202-
json result_oai = format_final_response_oaicompat(data, result.data);
3203+
json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
32033204

32043205
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
32053206
} else {
@@ -3208,11 +3209,11 @@ int main(int argc, char ** argv) {
32083209
}
32093210
ctx_server.queue_results.remove_waiting_task_id(id_task);
32103211
} else {
3211-
const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
3212+
const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
32123213
while (true) {
32133214
server_task_result result = ctx_server.queue_results.recv(id_task);
32143215
if (!result.error) {
3215-
std::vector<json> result_array = format_partial_response_oaicompat(result.data);
3216+
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
32163217

32173218
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
32183219
if (!it->empty()) {

Diff for: examples/server/utils.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ static json oaicompat_completion_params_parse(
378378
return llama_params;
379379
}
380380

381-
static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
381+
static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
382382
bool stopped_word = result.count("stopped_word") != 0;
383383
bool stopped_eos = json_value(result, "stopped_eos", false);
384384
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -412,7 +412,7 @@ static json format_final_response_oaicompat(const json & request, json result, b
412412
{"prompt_tokens", num_prompt_tokens},
413413
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
414414
}},
415-
{"id", gen_chatcmplid()}
415+
{"id", completion_id}
416416
};
417417

418418
if (server_verbose) {
@@ -427,7 +427,7 @@ static json format_final_response_oaicompat(const json & request, json result, b
427427
}
428428

429429
// return value is vector as there is one case where we might need to generate two responses
430-
static std::vector<json> format_partial_response_oaicompat(json result) {
430+
static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
431431
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
432432
return std::vector<json>({result});
433433
}
@@ -471,7 +471,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
471471
{"role", "assistant"}
472472
}}}})},
473473
{"created", t},
474-
{"id", gen_chatcmplid()},
474+
{"id", completion_id},
475475
{"model", modelname},
476476
{"object", "chat.completion.chunk"}};
477477

@@ -482,7 +482,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
482482
{"content", content}}}
483483
}})},
484484
{"created", t},
485-
{"id", gen_chatcmplid()},
485+
{"id", completion_id},
486486
{"model", modelname},
487487
{"object", "chat.completion.chunk"}};
488488

@@ -509,7 +509,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
509509
json ret = json {
510510
{"choices", choices},
511511
{"created", t},
512-
{"id", gen_chatcmplid()},
512+
{"id", completion_id},
513513
{"model", modelname},
514514
{"object", "chat.completion.chunk"}
515515
};

0 commit comments

Comments (0)