Add openai compatible embedding (#283)
* Add openai compatible embedding

* Add base64 encoding

* Fix comment
nguyenhoangthuan99 authored Nov 8, 2024
1 parent 3a09f85 commit fda2c59
Showing 5 changed files with 110 additions and 17 deletions.
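
The commit teaches the embedding endpoint to accept OpenAI-style inputs. Based on the branches added in HandleEmbeddingImpl below, "input" may be a single string, an array of strings, an array of int32 token ids, or an array of token-id arrays, and "encoding_format" may be "float" (the default) or "base64". A minimal request-building sketch with jsoncpp; the "model" field and function name are assumptions for illustration, only "input" and "encoding_format" appear in the diff:

#include <json/json.h>

Json::Value MakeBase64EmbeddingRequest() {
  Json::Value body;
  body["model"] = "my-embedding-model";  // placeholder model id (assumed field)
  body["encoding_format"] = "base64";    // or omit for the "float" default

  Json::Value tokens(Json::arrayValue);  // pre-tokenized input path
  for (int id : {15496, 2159, 0}) {
    tokens.append(id);
  }
  body["input"] = tokens;
  return body;
}
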
4 changes: 3 additions & 1 deletion base/cortex-common/enginei.h
@@ -4,6 +4,7 @@
#include <memory>

#include "json/value.h"
#include "trantor/utils/Logger.h"

// Interface for inference engine.
// Note: only append new function to keep the compatibility.
@@ -31,7 +32,7 @@ class EngineI {
virtual bool IsSupported(const std::string& f) {
if (f == "HandleChatCompletion" || f == "HandleEmbedding" ||
f == "LoadModel" || f == "UnloadModel" || f == "GetModelStatus" ||
f == "GetModels" || f == "SetFileLogger") {
f == "GetModels" || f == "SetFileLogger" || f == "SetLogLevel") {
return true;
}
return false;
@@ -44,4 +45,5 @@ class EngineI {
// API for set file logger
virtual void SetFileLogger(int max_log_lines,
const std::string& log_path) = 0;
virtual void SetLogLevel(trantor::Logger::LogLevel log_level) = 0;
};
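
The interface note above asks implementers only to append new functions, so the new SetLogLevel entry point is also advertised through IsSupported. A caller-side sketch (an assumption, not part of the diff) that probes before calling, since older engine builds may not implement it:

void MaybeSetLogLevel(EngineI* engine) {
  // Probe the appended capability before using it.
  if (engine->IsSupported("SetLogLevel")) {
    engine->SetLogLevel(trantor::Logger::LogLevel::kDebug);
  }
}
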
104 changes: 89 additions & 15 deletions src/llama_engine.cc
@@ -24,6 +24,25 @@ bool IsValidCacheType(const std::string& c) {
return true;
}

bool AreAllElementsInt32(const Json::Value& arr) {
if (!arr.isArray()) {
return false;
}

for (const auto& element : arr) {
if (!element.isInt()) {
return false;
}
// Check if value is within int32_t range
auto value = element.asInt();
if (value < std::numeric_limits<int32_t>::min() ||
value > std::numeric_limits<int32_t>::max()) {
return false;
}
}
return true;
}

struct InferenceState {
int task_id;
LlamaServerContext& llama;
@@ -48,17 +67,25 @@ std::shared_ptr<InferenceState> CreateInferenceState(LlamaServerContext& l) {
}

Json::Value CreateEmbeddingPayload(const std::vector<float>& embedding,
int prompt_tokens) {
int index, bool is_base64) {
Json::Value dataItem;

dataItem["object"] = "embedding";
dataItem["index"] = index;

if (is_base64) {
// Convert float vector to bytes
auto base64_str =
llama_utils::base64Encode(llama_utils::FloatVectorToBytes(embedding));

Json::Value embeddingArray(Json::arrayValue);
for (const auto& value : embedding) {
embeddingArray.append(value);
dataItem["embedding"] = base64_str;
} else {
// Original array format
Json::Value embeddingArray(Json::arrayValue);
for (const auto& value : embedding) {
embeddingArray.append(value);
}
dataItem["embedding"] = embeddingArray;
}
dataItem["embedding"] = embeddingArray;
dataItem["index"] = 0;

return dataItem;
}
@@ -410,11 +437,16 @@ void LlamaEngine::GetModels(
LOG_INFO << "Running models responded";
}

void LlamaEngine::SetLogLevel(trantor::Logger::LogLevel log_level) {
trantor::Logger::setLogLevel(log_level);
}

void LlamaEngine::SetFileLogger(int max_log_lines,
const std::string& log_path) {
if (!async_file_logger_) {
async_file_logger_ = std::make_unique<trantor::FileLogger>();
}

async_file_logger_->setFileName(log_path);
async_file_logger_->setMaxLines(max_log_lines);  // keep only the most recent max_log_lines lines
async_file_logger_->startLogging();
@@ -991,6 +1023,8 @@ void LlamaEngine::HandleEmbeddingImpl(
request_id,
mid = std::move(model_id)]() {
Json::Value responseData(Json::arrayValue);
bool is_base64 =
(*json_body).get("encoding_format", "float").asString() == "base64";

int prompt_tokens = 0;
if (json_body->isMember("input")) {
@@ -1003,21 +1037,61 @@
prompt_tokens +=
static_cast<int>(result.result_json["tokens_evaluated"]);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(CreateEmbeddingPayload(embedding_result, 0));
responseData.append(
CreateEmbeddingPayload(embedding_result, 0, is_base64));
} else if (input.isArray()) {
// Process each element in the array input
for (const auto& elem : input) {
if (elem.isString()) {
const int task_id = state->llama.RequestCompletion(
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
-1);
TaskResult result = state->llama.NextResult(task_id);
if (AreAllElementsInt32(input)) {
// Process the array of int32 tokens
state->task_id = state->llama.RequestCompletion(
{{"prompt", "Mock prompt"},
{"n_predict", 0},
{"prompt_tokens",
llama::inferences::ConvertJsonCppToNlohmann(input)}},
false, true, -1);
TaskResult result = state->llama.NextResult(state->task_id);
prompt_tokens +=
static_cast<int>(result.result_json["tokens_evaluated"]);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(
CreateEmbeddingPayload(embedding_result, 0, is_base64));
} else {

std::vector<int> task_ids;
int index = 0;
for (const auto& elem : input) {
if (elem.isString()) {
const int task_id = state->llama.RequestCompletion(
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
-1);

task_ids.push_back(task_id);
index++;
} else if (elem.isArray()) { // Check if elem is an array
bool all_int32 = AreAllElementsInt32(elem);

if (all_int32 && elem.size() > 0) {
// Convert token array to string representation for RequestCompletion

const int task_id = state->llama.RequestCompletion(
{{"prompt", "Mock prompt"},
{"n_predict", 0},
{"prompt_tokens",
llama::inferences::ConvertJsonCppToNlohmann(elem)}},
false, true, -1);
task_ids.push_back(task_id);
index++;
}
}
}
for (int i = 0; i < index; i++) {
TaskResult result = state->llama.NextResult(task_ids[i]);
int cur_pt = result.result_json["tokens_evaluated"];
prompt_tokens += cur_pt;
std::vector<float> embedding_result =
result.result_json["embedding"];
responseData.append(
CreateEmbeddingPayload(embedding_result, cur_pt));
CreateEmbeddingPayload(embedding_result, i, is_base64));
}
}
}
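
For reference, one entry of the responseData array assembled above, after CreateEmbeddingPayload runs with is_base64 == true. A sketch limited to the fields the diff actually sets; how the surrounding response object and usage counts are wrapped lies outside the shown hunks:

Json::Value ExampleBase64DataItem() {
  Json::Value item;
  item["object"] = "embedding";
  item["index"] = 0;
  // Base64 of the raw little-endian float bytes; e.g. {1.0f, 2.0f} encodes to:
  item["embedding"] = "AACAPwAAAEA=";
  return item;
}
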
3 changes: 3 additions & 0 deletions src/llama_engine.h
@@ -6,6 +6,7 @@
#include "llama.h"
#include "llama_server_context.h"
#include "trantor/utils/ConcurrentTaskQueue.h"
#include "trantor/utils/Logger.h"

class LlamaEngine : public EngineI {
public:
@@ -31,6 +32,8 @@ class LlamaEngine : public EngineI {
std::shared_ptr<Json::Value> jsonBody,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) final;
void SetFileLogger(int max_log_lines, const std::string& log_path) final;
void SetLogLevel(trantor::Logger::LogLevel log_level =
trantor::Logger::LogLevel::kInfo) final;

private:
bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
8 changes: 7 additions & 1 deletion src/llama_server_context.cc
@@ -495,6 +495,10 @@ bool LlamaServerContext::LaunchSlotWithData(LlamaClientSlot*& slot, json data) {
slot->sparams.ignore_eos =
json_value(data, "ignore_eos", default_sparams.ignore_eos);

slot->prompt_tokens =
json_value(data, "prompt_tokens", json::array()).get<std::vector<int>>();
slot->num_prompt_tokens = slot->prompt_tokens.size();

// infill
if (data.count("input_prefix") != 0) {
slot->params.input_prefix = data["input_prefix"];
@@ -637,7 +641,9 @@ bool LlamaServerContext::LaunchSlotWithData(LlamaClientSlot*& slot, json data) {
slot->smpl = common_sampler_init(model, slot->sparams);
// llama_set_rng_seed(ctx, slot->params.seed);
slot->command = SlotCommand::kLoadPrompt;
slot->prompt_tokens.clear();
if (slot->num_prompt_tokens == 0 && !slot->embedding) {
slot->prompt_tokens.clear();
}

all_slots_are_idle = false;

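
LaunchSlotWithData now copies an optional "prompt_tokens" array out of the request data and, when tokens were supplied for an embedding request, no longer clears them, so the pre-tokenized path from HandleEmbeddingImpl reaches the slot intact. A sketch of task data that exercises this branch, mirroring what RequestCompletion is given above; the nlohmann include path and function name are assumptions:

#include <nlohmann/json.hpp>

nlohmann::json MakePreTokenizedEmbeddingTask() {
  return nlohmann::json{
      {"prompt", "Mock prompt"},         // placeholder text; the tokens below are used instead
      {"n_predict", 0},                  // embedding only, no generation
      {"prompt_tokens", {15496, 2159}}   // consumed by the new LaunchSlotWithData branch
  };
}
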
8 changes: 8 additions & 0 deletions src/llama_utils.h
@@ -30,6 +30,14 @@ inline std::string extractBase64(const std::string& input) {
return "";
}

inline std::vector<unsigned char> FloatVectorToBytes(
const std::vector<float>& floats) {
const auto* float_bytes =
reinterpret_cast<const unsigned char*>(floats.data());
return std::vector<unsigned char>(
float_bytes, float_bytes + (floats.size() * sizeof(float)));
}

// Helper function to encode data to Base64
inline std::string base64Encode(const std::vector<unsigned char>& data) {
static const char encodingTable[] =
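
FloatVectorToBytes plus base64Encode cover the server-side direction only. A hypothetical client-side inverse (not part of this commit) that rebuilds the floats from the bytes obtained after base64-decoding the returned "embedding" string, assuming both ends share the same float width and endianness:

#include <cstring>
#include <vector>

inline std::vector<float> BytesToFloatVector(
    const std::vector<unsigned char>& bytes) {
  // Reinterpret the raw bytes as IEEE-754 floats, 4 bytes per value.
  std::vector<float> floats(bytes.size() / sizeof(float));
  std::memcpy(floats.data(), bytes.data(), floats.size() * sizeof(float));
  return floats;
}
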
