refactor: coding convention (#18)
* refactor: better coding convention

* refactor: enum

* refactor: class variables

* refactor: llama_engine

---------

Co-authored-by: vansangpfiev <sang@jan.ai>
vansangpfiev and sangjanai authored May 11, 2024
1 parent 7316198 commit 672084a
Showing 9 changed files with 2,009 additions and 2,310 deletions.
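
As a rough before/after illustration of the convention change described in the commit message (enum, class variables, file naming): the "after" side mirrors declarations added in this commit (see src/llama_client_slot.h below), while the "before" side is assumed from the upstream llama.cpp server style rather than quoted from the removed code.

#include <cstdint>

// Before (assumed, upstream llama.cpp server style): unscoped enums,
// snake_case methods, and PascalCase file names such as src/LlamaEngine.cc.
// enum slot_state { IDLE, PROCESSING };
// struct llama_client_slot { bool is_processing() const; };

// After (as introduced by this commit): scoped enum class types with an
// explicit underlying type and k-prefixed enumerators, PascalCase methods,
// and snake_case file names such as src/llama_engine.cc.
enum class SlotState : uint8_t {
  kIdle,
  kProcessing,
};

struct LlamaClientSlot {
  SlotState state = SlotState::kIdle;
  bool IsProcessing() const;
};
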
7 changes: 4 additions & 3 deletions CMakeLists.txt
@@ -8,7 +8,6 @@ if(UNIX AND NOT APPLE)
endif()

set(THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/build_deps/_install)
set(CORTEX_COMMON_PATH ${CMAKE_CURRENT_SOURCE_DIR}/base/)

if(UNIX AND NOT APPLE)
add_compile_options(-fPIC)
@@ -38,7 +37,9 @@ add_subdirectory(llama.cpp/examples/llava)
add_subdirectory(llama.cpp)

add_library(${TARGET} SHARED
src/LlamaEngine.cc
src/llama_engine.cc
src/llama_server_context.cc
src/llama_client_slot.cc
)

find_library(JSONCPP
@@ -54,6 +55,6 @@ find_library(TRANTOR
target_link_libraries(${TARGET} PRIVATE common llama llava ${JSONCPP} ${TRANTOR}
${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${TARGET} PRIVATE
${CORTEX_COMMON_PATH}
${CMAKE_CURRENT_SOURCE_DIR}/base
${CMAKE_CURRENT_SOURCE_DIR}/llama.cpp
${THIRD_PARTY_PATH}/include)
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/server/server.cc
@@ -1,4 +1,4 @@
#include "cortex-common/EngineI.h"
#include "cortex-common/enginei.h"
#include "dylib.h"
#include "httplib.h"
#include "json/reader.h"
94 changes: 94 additions & 0 deletions src/llama_client_slot.cc
@@ -0,0 +1,94 @@
#include "llama_client_slot.h"

void LlamaClientSlot::Reset() {
num_prompt_tokens = 0;
generated_text = "";
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
n_past = 0;
sent_count = 0;
sent_token_probs_index = 0;
infill = false;

generated_token_probs.clear();

for (SlotImage& img : images) {
free(img.image_embedding);
if (img.img_data) {
clip_image_u8_free(img.img_data);
}
img.prefix_prompt = "";
}

images.clear();
}

bool LlamaClientSlot::HasBudget(gpt_params& global_params) {
  n_remaining = -1;
  if (params.n_predict != -1) {
    n_remaining = params.n_predict - n_decoded;
  } else if (global_params.n_predict != -1) {
    n_remaining = global_params.n_predict - n_decoded;
  }
  return n_remaining > 0 || n_remaining == -1;  // budget remaining || limitless
}

bool LlamaClientSlot::Available() const {
  return state == SlotState::kIdle && command == SlotCommand::kNone;
}

bool LlamaClientSlot::IsProcessing() const {
  return (state == SlotState::kIdle && command == SlotCommand::kLoadPrompt) ||
         state == SlotState::kProcessing;
}

void LlamaClientSlot::AddTokenString(const CompletionTokenOutput& token) {
  if (command == SlotCommand::kRelease) {
    return;
  }
  generated_token_probs.push_back(token);
}

void LlamaClientSlot::Release() {
  if (state == SlotState::kIdle || state == SlotState::kProcessing) {
    t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
    command = SlotCommand::kRelease;
  }
}

json LlamaClientSlot::GetFormatedTimings() {
  return json{
      {"prompt_n", num_prompt_tokens_processed},
      {"prompt_ms", t_prompt_processing},
      {"prompt_per_token_ms",
       t_prompt_processing / num_prompt_tokens_processed},
      {"prompt_per_second",
       1e3 / t_prompt_processing * num_prompt_tokens_processed},

      {"predicted_n", n_decoded},
      {"predicted_ms", t_token_generation},
      {"predicted_per_token_ms", t_token_generation / n_decoded},
      {"predicted_per_second", 1e3 / t_token_generation * n_decoded},
  };
}

void LlamaClientSlot::PrintTimings() const {
  LOG_DEBUG << __func__ << ": prompt eval time = " << t_prompt_processing
            << "ms / " << num_prompt_tokens_processed << " tokens ("
            << t_prompt_processing / num_prompt_tokens_processed
            << " ms per token, "
            << 1e3 / t_prompt_processing * num_prompt_tokens_processed
            << " tokens per second)";
  LOG_DEBUG << __func__ << ": eval time = " << t_token_generation
            << " ms / " << n_decoded << " runs ("
            << t_token_generation / n_decoded << " ms per token, "
            << 1e3 / t_token_generation * n_decoded << " tokens per second)\n";
  LOG_DEBUG << __func__ << ": total time = "
            << t_prompt_processing + t_token_generation << " ms";
}
170 changes: 170 additions & 0 deletions src/llama_client_slot.h
@@ -0,0 +1,170 @@
#pragma once

#include <string>
#include <vector>

#include "common.h"
#include "json.hpp"
#include "llama.h"
#include "llava.h"
#include "stb_image.h"
#include "trantor/utils/Logger.h"

#include "clip.h"

static bool server_verbose = false;

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif

#if SERVER_VERBOSE != 1
#define LOG_VERBOSE(MSG, ...)
#else
#define LOG_VERBOSE(MSG, ...)                                      \
  do {                                                             \
    if (server_verbose) {                                          \
      server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
    }                                                              \
  } while (0)
#endif

#define LOG_ERROR_LLAMA(MSG, ...) \
  server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_WARNING_LLAMA(MSG, ...) \
  server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO_LLAMA(MSG, ...) \
  server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)

using json = nlohmann::json;

enum class SlotState : uint8_t {
  kIdle,
  kProcessing,
};

enum class SlotCommand : uint8_t {
  kNone,
  kLoadPrompt,
  kRelease,
};

struct SlotParams {
  bool stream = true;
  bool cache_prompt = false;  // remember the prompt to avoid reprocessing the whole prompt

  uint32_t seed = -1;      // RNG seed
  int32_t n_keep = 0;      // number of tokens to keep from initial prompt
  int32_t n_predict = -1;  // new tokens to predict

  std::vector<std::string> antiprompt;

  json input_prefix;
  json input_suffix;
};

struct SlotImage {
  int32_t id;

  bool request_encode_image = false;
  float* image_embedding = nullptr;
  int32_t image_tokens = 0;

  clip_image_u8* img_data;

  std::string prefix_prompt;  // prompt text that comes before this image
};

struct CompletionTokenOutput {
  struct TokenProb {
    llama_token tok;
    float prob;
  };

  std::vector<TokenProb> probs;
  llama_token tok;
  std::string text_to_send;
};

struct LlamaClientSlot {
  int id;
  int task_id = -1;

  struct SlotParams params;

  SlotState state = SlotState::kIdle;
  SlotCommand command = SlotCommand::kNone;

  // used to determine the slot that has been used the longest
  int64_t t_last_used = -1;

  // generation props
  int32_t n_ctx = 0;  // context size per slot
  int32_t n_past = 0;
  int32_t n_decoded = 0;
  int32_t n_remaining = -1;
  int32_t i_batch = -1;

  int32_t num_prompt_tokens = 0;
  int32_t num_prompt_tokens_processed = 0;

  json prompt;

  // when a task is submitted, we first tokenize the prompt and store it here
  std::vector<llama_token> prompt_tokens;

  std::string generated_text;
  llama_token sampled;
  std::vector<llama_token> cache_tokens;
  std::vector<CompletionTokenOutput> generated_token_probs;

  bool infill = false;
  bool embedding = false;
  bool has_next_token = true;
  bool truncated = false;
  bool stopped_eos = false;
  bool stopped_word = false;
  bool stopped_limit = false;

  bool oaicompat = false;
  std::string oaicompat_model;

  std::string stopping_word;

  // sampling
  struct llama_sampling_params sparams;
  llama_sampling_context* ctx_sampling = nullptr;

  // multimodal
  std::vector<SlotImage> images;

  // stats
  size_t sent_count = 0;
  size_t sent_token_probs_index = 0;

  int64_t t_start_process_prompt;
  int64_t t_start_genereration;

  double t_prompt_processing;  // ms
  double t_token_generation;   // ms

  // multitasks
  int multitask_id = -1;

  void Reset();

  bool HasBudget(gpt_params& global_params);

  bool Available() const;

  bool IsProcessing() const;

  void AddTokenString(const CompletionTokenOutput& token);

  void Release();

  json GetFormatedTimings();

  void PrintTimings() const;
};
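
For orientation, the sketch below shows roughly how a LlamaClientSlot is expected to be driven through its lifecycle. The real scheduling loop lives in src/llama_server_context.cc, which is not part of this excerpt, so the helper ProcessSlotSketch and its control flow are hypothetical, not the actual server code.

#include "llama_client_slot.h"

// Hypothetical driver illustrating the slot lifecycle implied by the API above;
// the actual loop in llama_server_context.cc batches tokens across many slots.
inline void ProcessSlotSketch(LlamaClientSlot& slot, gpt_params& params) {
  if (slot.Available()) {                     // kIdle with no pending command
    slot.Reset();                             // clear per-request state
    slot.command = SlotCommand::kLoadPrompt;  // slot now counts as processing
  }
  while (slot.IsProcessing() && slot.HasBudget(params)) {
    // decode one token here, then record it:
    // slot.AddTokenString(token_output);
    break;  // placeholder so the sketch terminates
  }
  slot.Release();       // mark the slot for release and stop generation timing
  slot.PrintTimings();  // logs the timing fields populated during real decoding
}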