fix: responses from /chat/completions endpoint contain a leading space in the content #488

Merged 4 commits on Apr 12, 2024
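In short: the first token produced by the llama.cpp backend typically carries a leading space, so the content field returned by the OpenAI-compatible /chat/completions endpoint started with a stray space (illustratively, " Hello there" instead of "Hello there"). The change adds an ltrim helper to nitro_utils, left-trims the first token in the streamed path and the full content in the non-streamed path, and introduces a GoogleTest-based test target under test/components that covers the new helper and ChatCompletionRequest parsing.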
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -58,6 +58,7 @@ add_compile_definitions(NITRO_VERSION="${NITRO_VERSION}")
add_subdirectory(llama.cpp/examples/llava)
add_subdirectory(llama.cpp)
add_subdirectory(whisper.cpp)
add_subdirectory(test)

add_executable(${PROJECT_NAME} main.cc)

55 changes: 38 additions & 17 deletions controllers/llamaCPP.cc
@@ -3,8 +3,8 @@
#include <fstream>
#include <iostream>
#include "log.h"
#include "utils/nitro_utils.h"
#include "utils/logging_utils.h"
#include "utils/nitro_utils.h"

// External
#include "common.h"
@@ -29,6 +29,8 @@ struct inferenceState {
int task_id;
InferenceStatus inference_status = PENDING;
llamaCPP* instance;
// True until the first token is received; set to false after it arrives
bool is_first_token = true;

inferenceState(llamaCPP* inst) : instance(inst) {}
};
@@ -208,7 +210,8 @@ void llamaCPP::InferenceImpl(

// Passing load value
data["repeat_last_n"] = this->repeat_last_n;
LOG_INFO_REQUEST(request_id) << "Stop words:" << completion.stop.toStyledString();
LOG_INFO_REQUEST(request_id)
<< "Stop words:" << completion.stop.toStyledString();

data["stream"] = completion.stream;
data["n_predict"] = completion.max_tokens;
@@ -267,7 +270,8 @@ void llamaCPP::InferenceImpl(
auto image_url = content_piece["image_url"]["url"].asString();
std::string base64_image_data;
if (image_url.find("http") != std::string::npos) {
LOG_INFO_REQUEST(request_id) << "Remote image detected but not supported yet";
LOG_INFO_REQUEST(request_id)
<< "Remote image detected but not supported yet";
} else if (image_url.find("data:image") != std::string::npos) {
LOG_INFO_REQUEST(request_id) << "Base64 image detected";
base64_image_data = nitro_utils::extractBase64(image_url);
@@ -328,16 +332,19 @@ void llamaCPP::InferenceImpl(
if (is_streamed) {
LOG_INFO_REQUEST(request_id) << "Streamed, waiting for respone";
auto state = create_inference_state(this);
auto chunked_content_provider =
[state, data, request_id](char* pBuffer, std::size_t nBuffSize) -> std::size_t {

auto chunked_content_provider = [state, data, request_id](
char* pBuffer,
std::size_t nBuffSize) -> std::size_t {
if (state->inference_status == PENDING) {
state->inference_status = RUNNING;
} else if (state->inference_status == FINISHED) {
return 0;
}

if (!pBuffer) {
LOG_WARN_REQUEST(request_id) "Connection closed or buffer is null. Reset context";
LOG_WARN_REQUEST(request_id)
"Connection closed or buffer is null. Reset context";
state->inference_status = FINISHED;
return 0;
}
@@ -350,7 +357,8 @@ void llamaCPP::InferenceImpl(
"stop") +
"\n\n" + "data: [DONE]" + "\n\n";

LOG_VERBOSE("data stream", {{"request_id": request_id}, {"to_send", str}});
LOG_VERBOSE("data stream",
{{"request_id": request_id}, {"to_send", str}});
std::size_t nRead = std::min(str.size(), nBuffSize);
memcpy(pBuffer, str.data(), nRead);
state->inference_status = FINISHED;
@@ -359,7 +367,13 @@

task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
const std::string to_send = result.result_json["content"];
std::string to_send = result.result_json["content"];

// trim the leading space if it is the first token
if (std::exchange(state->is_first_token, false)) {
nitro_utils::ltrim(to_send);
}

const std::string str =
"data: " +
create_return_json(nitro_utils::generate_random_string(20), "_",
@@ -410,7 +424,8 @@
retries += 1;
}
if (state->inference_status != RUNNING)
LOG_INFO_REQUEST(request_id) << "Wait for task to be released:" << state->task_id;
LOG_INFO_REQUEST(request_id)
<< "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
LOG_INFO_REQUEST(request_id) << "Task completed, release it";
@@ -428,9 +443,11 @@
if (!result.error && result.stop) {
int prompt_tokens = result.result_json["tokens_evaluated"];
int predicted_tokens = result.result_json["tokens_predicted"];
respData = create_full_return_json(nitro_utils::generate_random_string(20),
"_", result.result_json["content"], "_",
prompt_tokens, predicted_tokens);
std::string to_send = result.result_json["content"];
nitro_utils::ltrim(to_send);
respData = create_full_return_json(
nitro_utils::generate_random_string(20), "_", to_send, "_",
prompt_tokens, predicted_tokens);
} else {
respData["message"] = "Internal error during inference";
LOG_ERROR_REQUEST(request_id) << "Error during inference";
@@ -463,7 +480,8 @@ void llamaCPP::EmbeddingImpl(
// Queue embedding task
auto state = create_inference_state(this);

state->instance->queue->runTaskInQueue([this, state, jsonBody, callback, request_id]() {
state->instance->queue->runTaskInQueue([this, state, jsonBody, callback,
request_id]() {
Json::Value responseData(Json::arrayValue);

if (jsonBody->isMember("input")) {
@@ -535,7 +553,7 @@ void llamaCPP::ModelStatus(
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
callback(resp);
LOG_INFO << "Model status responded";
}
}
}

void llamaCPP::LoadModel(
@@ -545,10 +563,12 @@ void llamaCPP::LoadModel(
if (!nitro_utils::isAVX2Supported() && ggml_cpu_has_avx2()) {
LOG_ERROR << "AVX2 is not supported by your processor";
Json::Value jsonResp;
jsonResp["message"] = "AVX2 is not supported by your processor, please download and replace the correct Nitro asset version";
jsonResp["message"] =
"AVX2 is not supported by your processor, please download and replace "
"the correct Nitro asset version";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k500InternalServerError);
callback(resp);
callback(resp);
return;
}

@@ -615,7 +635,8 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
if (model_path.isNull()) {
LOG_ERROR << "Missing model path in request";
} else {
if (std::filesystem::exists(std::filesystem::path(model_path.asString()))) {
if (std::filesystem::exists(
std::filesystem::path(model_path.asString()))) {
params.model = model_path.asString();
} else {
LOG_ERROR << "Could not find model in path " << model_path.asString();
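For reference, a minimal standalone sketch of the trimming idea used above: is_first_token starts as true, std::exchange returns that old value while flipping the flag to false, so the left-trim runs on the first chunk only. The ltrim below mirrors the nitro_utils::ltrim added in this PR; the token values are made up for illustration.

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Same behaviour as the nitro_utils::ltrim introduced in this PR.
static void ltrim(std::string& s) {
  s.erase(s.begin(), std::find_if(s.begin(), s.end(),
                                  [](unsigned char ch) { return !std::isspace(ch); }));
}

int main() {
  // Simulated token stream; the first chunk carries the unwanted leading space.
  std::vector<std::string> chunks = {" Hello", " there", "!"};
  bool is_first_token = true;
  std::string content;
  for (std::string chunk : chunks) {
    // std::exchange stores false and returns the previous value,
    // so the trim happens exactly once.
    if (std::exchange(is_first_token, false)) {
      ltrim(chunk);
    }
    content += chunk;
  }
  std::cout << content << "\n";  // prints "Hello there!"
}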
4 changes: 2 additions & 2 deletions models/chat_completion_request.h
@@ -5,8 +5,8 @@ namespace inferences {
struct ChatCompletionRequest {
bool stream = false;
int max_tokens = 500;
float top_p = 0.95;
float temperature = 0.8;
float top_p = 0.95f;
float temperature = 0.8f;
float frequency_penalty = 0;
float presence_penalty = 0;
Json::Value stop = Json::Value(Json::arrayValue);
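Side note on the f suffix above: an unsuffixed floating literal is a double, so the previous initializers went through an implicit double-to-float conversion; the suffixed form is a float literal to begin with. A tiny illustration (mine, not part of the diff):

#include <type_traits>

static_assert(std::is_same_v<decltype(0.95), double>,
              "an unsuffixed floating literal has type double");
static_assert(std::is_same_v<decltype(0.95f), float>,
              "an f-suffixed literal has type float");

int main() {
  float top_p = 0.95f;  // initialized directly from a float literal
  (void)top_p;
  return 0;
}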
15 changes: 14 additions & 1 deletion nitro_deps/CMakeLists.txt
@@ -79,8 +79,21 @@ ExternalProject_Add(
-DCMAKE_INSTALL_PREFIX=${THIRD_PARTY_INSTALL_PATH}
)

# Fix trantor cmakelists to link c-ares on Windows
# Download and install GoogleTest
ExternalProject_Add(
gtest
GIT_REPOSITORY https://github.com/google/googletest
GIT_TAG v1.14.0
CMAKE_ARGS
-Dgtest_force_shared_crt=ON
-DCMAKE_BUILD_TYPE=release
-DCMAKE_PREFIX_PATH=${THIRD_PARTY_INSTALL_PATH}
-DCMAKE_INSTALL_PREFIX=${THIRD_PARTY_INSTALL_PATH}
)


if(WIN32)
# Fix trantor cmakelists to link c-ares on Windows
set(TRANTOR_CMAKE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/../build_deps/nitro_deps/drogon-prefix/src/drogon/trantor/CMakeLists.txt)
ExternalProject_Add_Step(drogon trantor_custom_target
COMMAND ${CMAKE_COMMAND} -E echo add_definitions(-DCARES_STATICLIB) >> ${TRANTOR_CMAKE_FILE}
2 changes: 2 additions & 0 deletions test/CMakeLists.txt
@@ -0,0 +1,2 @@

add_subdirectory(components)
16 changes: 16 additions & 0 deletions test/components/CMakeLists.txt
@@ -0,0 +1,16 @@
file(GLOB SRCS *.cc)
project(test-components)

enable_testing()

add_executable(${PROJECT_NAME} ${SRCS})

find_package(Drogon CONFIG REQUIRED)
find_package(GTest CONFIG REQUIRED)

target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gmock
${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)

add_test(NAME ${PROJECT_NAME}
COMMAND ${PROJECT_NAME})
9 changes: 9 additions & 0 deletions test/components/main.cc
@@ -0,0 +1,9 @@
#include "gtest/gtest.h"
#include <drogon/HttpAppFramework.h>
#include <drogon/drogon.h>

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS();
return ret;
}
53 changes: 53 additions & 0 deletions test/components/test_models.cc
@@ -0,0 +1,53 @@
#include "gtest/gtest.h"
#include "models/chat_completion_request.h"

using inferences::ChatCompletionRequest;

class ModelTest : public ::testing::Test {
};


TEST_F(ModelTest, should_parse_request) {
{
Json::Value data;
auto req = drogon::HttpRequest::newHttpJsonRequest(data);

auto res =
drogon::fromRequest<inferences::ChatCompletionRequest>(*req.get());

EXPECT_EQ(res.stream, false);
EXPECT_EQ(res.max_tokens, 500);
EXPECT_EQ(res.top_p, 0.95f);
EXPECT_EQ(res.temperature, 0.8f);
EXPECT_EQ(res.frequency_penalty, 0);
EXPECT_EQ(res.presence_penalty, 0);
EXPECT_EQ(res.stop, Json::Value{});
EXPECT_EQ(res.messages, Json::Value{});
}

{
Json::Value data;
data["stream"] = true;
data["max_tokens"] = 400;
data["top_p"] = 0.8;
data["temperature"] = 0.7;
data["frequency_penalty"] = 0.1;
data["presence_penalty"] = 0.2;
data["messages"] = "message";
data["stop"] = "stop";

auto req = drogon::HttpRequest::newHttpJsonRequest(data);

auto res =
drogon::fromRequest<inferences::ChatCompletionRequest>(*req.get());

EXPECT_EQ(res.stream, true);
EXPECT_EQ(res.max_tokens, 400);
EXPECT_EQ(res.top_p, 0.8f);
EXPECT_EQ(res.temperature, 0.7f);
EXPECT_EQ(res.frequency_penalty, 0.1f);
EXPECT_EQ(res.presence_penalty, 0.2f);
EXPECT_EQ(res.stop, Json::Value{"stop"});
EXPECT_EQ(res.messages, Json::Value{"message"});
}
}
41 changes: 41 additions & 0 deletions test/components/test_nitro_utils.cc
@@ -0,0 +1,41 @@
#include "gtest/gtest.h"
#include "utils/nitro_utils.h"

class NitroUtilTest : public ::testing::Test {
};

TEST_F(NitroUtilTest, left_trim) {
{
std::string empty;
nitro_utils::ltrim(empty);
EXPECT_EQ(empty, "");
}

{
std::string s = "abc";
std::string expected = "abc";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}

{
std::string s = " abc";
std::string expected = "abc";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}

{
std::string s = "1 abc 2 ";
std::string expected = "1 abc 2 ";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}

{
std::string s = " |abc";
std::string expected = "|abc";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}
}
8 changes: 7 additions & 1 deletion utils/nitro_utils.h
@@ -165,7 +165,7 @@ inline std::string generate_random_string(std::size_t length) {
std::random_device rd;
std::mt19937 generator(rd());

std::uniform_int_distribution<> distribution(0, characters.size() - 1);
std::uniform_int_distribution<> distribution(0, static_cast<int>(characters.size()) - 1);

std::string random_string(length, '\0');
std::generate_n(random_string.begin(), length,
@@ -276,4 +276,10 @@ inline drogon::HttpResponsePtr nitroStreamResponse(
return resp;
}

inline void ltrim(std::string& s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
return !std::isspace(ch);
}));
};

} // namespace nitro_utils
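A usage sketch for the new helper, matching the unit tests above: ltrim only erases leading whitespace, and the unsigned char cast in the predicate keeps std::isspace well-defined for char values that would otherwise be negative. The include path assumes the repository's include layout.

#include <cassert>
#include <string>

#include "utils/nitro_utils.h"

int main() {
  std::string s = "  Hello";
  nitro_utils::ltrim(s);
  assert(s == "Hello");        // leading whitespace removed

  std::string t = "1 abc 2 ";
  nitro_utils::ltrim(t);
  assert(t == "1 abc 2 ");     // interior and trailing whitespace left alone
  return 0;
}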