Added support for chat.completion to return SSE responses #2910

Open · wants to merge 1 commit into base: main
gpt4all-chat/src/server.cpp (265 changes: 82 additions & 183 deletions)
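Summary of the change, as far as the diff shows: Server::handleCompletionRequest no longer builds a single JSON completion object with "choices" and "usage" fields. Instead it splits each generated response on spaces and writes one OpenAI-style chat.completion.chunk event per token, then a final chunk with a null delta and finish_reason "stop", then a data: [DONE] terminator, and returns the buffer with Content-Type: text/event-stream. For illustration only (the id, timestamp, and model name here are invented placeholders, not output from a real run), the body the handler builds looks like:

data: {"id":"chatcmpl-<random hex id>","object":"chat.completion.chunk","created":1724140800,"model":"<installed model name>","choices":[{"index":0,"delta":{"content":"Hello "}}]}

data: {"id":"chatcmpl-<random hex id>","object":"chat.completion.chunk","created":1724140800,"model":"<installed model name>","choices":[{"index":0,"delta":{"content":null},"finish_reason":"stop"}]}

data: [DONE]

Note that the handler still accumulates the entire event stream in responseData and returns it as one QHttpServerResponse, so a client receives all events in a single HTTP body rather than incrementally.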
@@ -19,8 +19,6 @@
 #include <QtLogging>
 
 #include <iostream>
-#include <string>
-#include <type_traits>
 #include <utility>
 
 using namespace Qt::Literals::StringLiterals;
@@ -207,26 +205,29 @@ void Server::start()
 
 QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &request, bool isChat)
 {
-    // We've been asked to do a completion...
+    // Parse JSON request
     QJsonParseError err;
     const QJsonDocument document = QJsonDocument::fromJson(request.body(), &err);
     if (err.error || !document.isObject()) {
-        std::cerr << "ERROR: invalid json in completions body" << std::endl;
+        std::cerr << "ERROR: invalid JSON in completions body" << std::endl;
         return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
     }
 
 #if defined(DEBUG)
     printf("/v1/completions %s\n", qPrintable(document.toJson(QJsonDocument::Indented)));
     fflush(stdout);
 #endif
 
     const QJsonObject body = document.object();
-    if (!body.contains("model")) { // required
-        std::cerr << "ERROR: completions contains no model" << std::endl;
+    if (!body.contains("model")) {
+        std::cerr << "ERROR: completions contain no model" << std::endl;
         return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
     }
 
     QJsonArray messages;
     if (isChat) {
         if (!body.contains("messages")) {
-            std::cerr << "ERROR: chat completions contains no messages" << std::endl;
+            std::cerr << "ERROR: chat completions contain no messages" << std::endl;
             return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
         }
         messages = body["messages"].toArray();
@@ -236,16 +237,12 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
     ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
     const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
     for (const ModelInfo &info : modelList) {
-        Q_ASSERT(info.installed);
-        if (!info.installed)
-            continue;
-        if (modelRequested == info.name() || modelRequested == info.filename()) {
+        if (info.installed && (modelRequested == info.name() || modelRequested == info.filename())) {
             modelInfo = info;
             break;
         }
     }
 
-    // We only support one prompt for now
     QList<QString> prompts;
     if (body.contains("prompt")) {
         QJsonValue promptValue = body["prompt"];
@@ -256,217 +253,119 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
             for (const QJsonValue &v : array)
                 prompts.append(v.toString());
         }
-    } else
+    } else {
         prompts.append(" ");
+    }
 
-    int max_tokens = 16;
-    if (body.contains("max_tokens"))
-        max_tokens = body["max_tokens"].toInt();
-
-    float temperature = 1.f;
-    if (body.contains("temperature"))
-        temperature = body["temperature"].toDouble();
-
-    float top_p = 1.f;
-    if (body.contains("top_p"))
-        top_p = body["top_p"].toDouble();
-
-    float min_p = 0.f;
-    if (body.contains("min_p"))
-        min_p = body["min_p"].toDouble();
-
-    int n = 1;
-    if (body.contains("n"))
-        n = body["n"].toInt();
-
-    int logprobs = -1; // supposed to be null by default??
-    if (body.contains("logprobs"))
-        logprobs = body["logprobs"].toInt();
-
-    bool echo = false;
-    if (body.contains("echo"))
-        echo = body["echo"].toBool();
-
-    // We currently don't support any of the following...
-#if 0
-    // FIXME: Need configurable reverse prompts
-    QList<QString> stop;
-    if (body.contains("stop")) {
-        QJsonValue stopValue = body["stop"];
-        if (stopValue.isString())
-            stop.append(stopValue.toString());
-        else {
-            QJsonArray array = stopValue.toArray();
-            for (QJsonValue v : array)
-                stop.append(v.toString());
-        }
-    }
-
-    // FIXME: QHttpServer doesn't support server-sent events
-    bool stream = false;
-    if (body.contains("stream"))
-        stream = body["stream"].toBool();
-
-    // FIXME: What does this do?
-    QString suffix;
-    if (body.contains("suffix"))
-        suffix = body["suffix"].toString();
-
-    // FIXME: We don't support
-    float presence_penalty = 0.f;
-    if (body.contains("presence_penalty"))
-        top_p = body["presence_penalty"].toDouble();
-
-    // FIXME: We don't support
-    float frequency_penalty = 0.f;
-    if (body.contains("frequency_penalty"))
-        top_p = body["frequency_penalty"].toDouble();
-
-    // FIXME: We don't support
-    int best_of = 1;
-    if (body.contains("best_of"))
-        logprobs = body["best_of"].toInt();
-
-    // FIXME: We don't need
-    QString user;
-    if (body.contains("user"))
-        suffix = body["user"].toString();
-#endif
+    int max_tokens = body.value("max_tokens").toInt(16);
+    float temperature = body.value("temperature").toDouble(1.0);
+    float top_p = body.value("top_p").toDouble(1.0);
+    float min_p = body.value("min_p").toDouble(0.0);
+    int n = body.value("n").toInt(1);
+    bool echo = body.value("echo").toBool(false);
 
     QString actualPrompt = prompts.first();
 
     // if we're a chat completion we have messages which means we need to prepend these to the prompt
     if (!messages.isEmpty()) {
         QList<QString> chats;
-        for (int i = 0; i < messages.count(); ++i) {
-            QJsonValue v = messages.at(i);
-            // FIXME: Deal with system messages correctly
-            QString role = v.toObject()["role"].toString();
-            if (role != "user")
-                continue;
-            QString content = v.toObject()["content"].toString();
+        for (int i = 0; i < messages.count(); ++i) {
+            QString content = messages.at(i).toObject()["content"].toString();
             if (!content.endsWith("\n") && i < messages.count() - 1)
                 content += "\n";
             chats.append(content);
         }
         actualPrompt.prepend(chats.join("\n"));
     }
 
     // adds prompt/response items to GUI
     emit requestServerNewPromptResponsePair(actualPrompt); // blocks
 
     // load the new model if necessary
     setShouldBeLoaded(true);
 
     if (modelInfo.filename().isEmpty()) {
         std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;
         return QHttpServerResponse(QHttpServerResponder::StatusCode::BadRequest);
-    }
-
-    // NB: this resets the context, regardless of whether this model is already loaded
-    if (!loadModel(modelInfo)) {
+    } else if (!loadModel(modelInfo)) {
         std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
         return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
     }
 
-    const QString promptTemplate = modelInfo.promptTemplate();
-    const float top_k = modelInfo.topK();
-    const int n_batch = modelInfo.promptBatchSize();
-    const float repeat_penalty = modelInfo.repeatPenalty();
-    const int repeat_last_n = modelInfo.repeatPenaltyTokens();
     resetContext();
 
+    QByteArray responseData;
+    QTextStream stream(&responseData, QIODevice::WriteOnly);
+
+    QString randomId = "chatcmpl-" + QUuid::createUuid().toString(QUuid::WithoutBraces).replace("-", "");
+
-    int promptTokens = 0;
-    int responseTokens = 0;
-    QList<QPair<QString, QList<ResultInfo>>> responses;
     for (int i = 0; i < n; ++i) {
-        if (!promptInternal(
-                m_collections,
-                actualPrompt,
-                promptTemplate,
-                max_tokens /*n_predict*/,
-                top_k,
-                top_p,
-                min_p,
-                temperature,
-                n_batch,
-                repeat_penalty,
-                repeat_last_n)) {
+        if (!promptInternal(m_collections,
+                            actualPrompt,
+                            modelInfo.promptTemplate(),
+                            max_tokens /*n_predict*/,
+                            modelInfo.topK(),
+                            top_p,
+                            min_p,
+                            temperature,
+                            modelInfo.promptBatchSize(),
+                            modelInfo.repeatPenalty(),
+                            modelInfo.repeatPenaltyTokens())) {
+
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
}
QString echoedPrompt = actualPrompt;
if (!echoedPrompt.endsWith("\n"))
echoedPrompt += "\n";
responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_databaseResults));
if (!promptTokens)
promptTokens += m_promptTokens;
responseTokens += m_promptResponseTokens - m_promptTokens;
if (i != n - 1)
resetResponse();
}

QJsonObject responseObject;
responseObject.insert("id", "foobarbaz");
responseObject.insert("object", "text_completion");
responseObject.insert("created", QDateTime::currentSecsSinceEpoch());
responseObject.insert("model", modelInfo.name());
QString result = (echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response();

QJsonArray choices;
for (const QString &token : result.split(' ')) {
QJsonObject delta;
delta.insert("content", token + " ");

-    if (isChat) {
-        int index = 0;
-        for (const auto &r : responses) {
-            QString result = r.first;
-            QList<ResultInfo> infos = r.second;
-            QJsonObject choice;
-            choice.insert("index", index++);
-            choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
-            QJsonObject message;
-            message.insert("role", "assistant");
-            message.insert("content", result);
-            choice.insert("message", message);
-            if (MySettings::globalInstance()->localDocsShowReferences()) {
-                QJsonArray references;
-                for (const auto &ref : infos)
-                    references.append(resultToJson(ref));
-                choice.insert("references", references);
-            }
-            choices.append(choice);
-        }
-    } else {
-        int index = 0;
-        for (const auto &r : responses) {
-            QString result = r.first;
-            QList<ResultInfo> infos = r.second;
-            QJsonObject choice;
-            choice.insert("text", result);
-            choice.insert("index", index++);
-            choice.insert("logprobs", QJsonValue::Null); // We don't support
-            choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
-            if (MySettings::globalInstance()->localDocsShowReferences()) {
-                QJsonArray references;
-                for (const auto &ref : infos)
-                    references.append(resultToJson(ref));
-                choice.insert("references", references);
-            }
-            choices.append(choice);
-        }
-    }
+            QJsonObject choice;
+            choice.insert("index", i);
+            choice.insert("delta", delta);
+
+            QJsonObject responseChunk;
+            responseChunk.insert("id", randomId);
+            responseChunk.insert("object", "chat.completion.chunk");
+            responseChunk.insert("created", QDateTime::currentSecsSinceEpoch());
+            responseChunk.insert("model", modelInfo.name());
+            responseChunk.insert("choices", QJsonArray{choice});
+
+            stream << "data: " << QJsonDocument(responseChunk).toJson(QJsonDocument::Compact) << "\n\n";
+            stream.flush();
+        }
+
+        if (i != n - 1)
+            resetResponse();
+    }
 
responseObject.insert("choices", choices);
// Final empty delta to signify the end of the stream
QJsonObject delta;
delta.insert("content", QJsonValue::Null);

QJsonObject usage;
usage.insert("prompt_tokens", int(promptTokens));
usage.insert("completion_tokens", int(responseTokens));
usage.insert("total_tokens", int(promptTokens + responseTokens));
responseObject.insert("usage", usage);
QJsonObject choice;
choice.insert("index", 0);
choice.insert("delta", delta);
choice.insert("finish_reason", "stop");

#if defined(DEBUG)
QJsonDocument newDoc(responseObject);
printf("/v1/completions %s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented)));
fflush(stdout);
#endif
QJsonObject finalChunk;
finalChunk.insert("id", randomId);
finalChunk.insert("object", "chat.completion.chunk");
finalChunk.insert("created", QDateTime::currentSecsSinceEpoch());
finalChunk.insert("model", modelInfo.name());
finalChunk.insert("choices", QJsonArray{choice});

stream << "data: " << QJsonDocument(finalChunk).toJson(QJsonDocument::Compact) << "\n\n";
stream << "data: [DONE]\n\n";
stream.flush();

// Log the entire response data
qDebug() << "Full SSE Response:\n" << responseData;

// Create the response
QHttpServerResponse response(responseData, QHttpServerResponder::StatusCode::Ok);
response.setHeader("Content-Type", "text/event-stream");
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");

return QHttpServerResponse(responseObject);
return response;
}
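
For reference, a minimal Qt client consuming this endpoint might look like the sketch below. This is not part of the PR: the port (4891, GPT4All's default API server port), the model name, and the line-at-a-time parsing are assumptions for illustration, and a robust client would buffer partial SSE lines across readyRead signals.

// Sketch of a minimal client for the stream above. Placeholders: port and
// model name. Each "data: {...}" line carries one chat.completion.chunk.
#include <QCoreApplication>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QNetworkRequest>
#include <QUrl>
#include <cstdio>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QNetworkRequest request(QUrl("http://localhost:4891/v1/chat/completions"));
    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");

    const QJsonObject body{
        {"model", "<installed model name>"},
        {"messages", QJsonArray{QJsonObject{{"role", "user"}, {"content", "Hello!"}}}},
    };

    QNetworkAccessManager manager;
    QNetworkReply *reply = manager.post(request, QJsonDocument(body).toJson());

    // Print the "delta" content of each chunk as it arrives; the final chunk
    // has a null content, so it prints nothing, and [DONE] is skipped.
    QObject::connect(reply, &QNetworkReply::readyRead, reply, [reply] {
        for (const QByteArray &line : reply->readAll().split('\n')) {
            if (!line.startsWith("data: ") || line == "data: [DONE]")
                continue;
            const QJsonDocument chunk = QJsonDocument::fromJson(line.mid(6));
            const QString token = chunk["choices"][0]["delta"]["content"].toString();
            printf("%s", qPrintable(token));
            fflush(stdout);
        }
    });
    QObject::connect(reply, &QNetworkReply::finished, &app, &QCoreApplication::quit);

    return app.exec();
}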