* refactor: better coding convention
* refactor: enum
* refactor: class variables
* refactor: llama_engine

Co-authored-by: vansangpfiev <sang@jan.ai>
1 parent: 7316198
Commit: 672084a
9 changed files with 2,009 additions and 2,310 deletions.
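The bullets above mention an enum refactor and a new member-naming convention. As a minimal, illustrative sketch of that kind of change — the "before" shape is an assumption about the pre-refactor code, while the "after" mirrors the SlotState enum introduced in this commit:

// Illustrative only: "before" is assumed; "after" matches the new convention.
#include <cstdint>

// Before (assumed): unscoped enum whose SCREAMING_CASE values leak into the
// enclosing namespace.
enum slot_state {
  SLOT_STATE_IDLE,
  SLOT_STATE_PROCESSING,
};

// After: scoped enum with a fixed underlying type and k-prefixed CamelCase
// enumerators, as used in llama_client_slot.h below.
enum class SlotState : uint8_t {
  kIdle,
  kProcessing,
};

int main() {
  SlotState state = SlotState::kIdle;  // enumerators must be qualified
  return state == SlotState::kIdle ? 0 : 1;
}

Scoped enums avoid implicit conversions to int and keep enumerator names out of the surrounding scope, which is the usual motivation for this kind of cleanup.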
File renamed without changes.
@@ -0,0 +1,94 @@
#include "llama_client_slot.h"

void LlamaClientSlot::Reset() {
  num_prompt_tokens = 0;
  generated_text = "";
  truncated = false;
  stopped_eos = false;
  stopped_word = false;
  stopped_limit = false;
  stopping_word = "";
  n_past = 0;
  sent_count = 0;
  sent_token_probs_index = 0;
  infill = false;

  generated_token_probs.clear();

  for (SlotImage& img : images) {
    free(img.image_embedding);
    if (img.img_data) {
      clip_image_u8_free(img.img_data);
    }
    img.prefix_prompt = "";
  }

  images.clear();
}

bool LlamaClientSlot::HasBudget(gpt_params& global_params) {
  n_remaining = -1;
  if (params.n_predict != -1) {
    n_remaining = params.n_predict - n_decoded;
  } else if (global_params.n_predict != -1) {
    n_remaining = global_params.n_predict - n_decoded;
  }
  return n_remaining > 0 || n_remaining == -1;  // budget left || no limit set
}

bool LlamaClientSlot::Available() const {
  return state == SlotState::kIdle && command == SlotCommand::kNone;
}

bool LlamaClientSlot::IsProcessing() const {
  return (state == SlotState::kIdle && command == SlotCommand::kLoadPrompt) ||
         state == SlotState::kProcessing;
}

void LlamaClientSlot::AddTokenString(const CompletionTokenOutput& token) {
  if (command == SlotCommand::kRelease) {
    return;
  }
  generated_token_probs.push_back(token);
}

void LlamaClientSlot::Release() {
  if (state == SlotState::kIdle || state == SlotState::kProcessing) {
    t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
    command = SlotCommand::kRelease;
  }
}

json LlamaClientSlot::GetFormatedTimings() {
  // note: assumes num_prompt_tokens_processed and n_decoded are non-zero
  return json{
      {"prompt_n", num_prompt_tokens_processed},
      {"prompt_ms", t_prompt_processing},
      {"prompt_per_token_ms",
       t_prompt_processing / num_prompt_tokens_processed},
      {"prompt_per_second",
       1e3 / t_prompt_processing * num_prompt_tokens_processed},

      {"predicted_n", n_decoded},
      {"predicted_ms", t_token_generation},
      {"predicted_per_token_ms", t_token_generation / n_decoded},
      {"predicted_per_second", 1e3 / t_token_generation * n_decoded},
  };
}

void LlamaClientSlot::PrintTimings() const {
  LOG_DEBUG << __func__ << ": prompt eval time = " << t_prompt_processing
            << " ms / " << num_prompt_tokens_processed << " tokens ("
            << t_prompt_processing / num_prompt_tokens_processed
            << " ms per token, "
            << 1e3 / t_prompt_processing * num_prompt_tokens_processed
            << " tokens per second)";
  LOG_DEBUG << __func__ << ": eval time = " << t_token_generation
            << " ms / " << n_decoded << " runs ("
            << t_token_generation / n_decoded << " ms per token, "
            << 1e3 / t_token_generation * n_decoded << " tokens per second)\n";
  LOG_DEBUG << __func__ << ": total time = "
            << t_prompt_processing + t_token_generation << " ms";
}
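A hedged, caller-side sketch (not part of this commit) of how an engine loop might drive the methods defined above; DrainSlot and the pending-tokens vector are hypothetical names used only for illustration:

#include "llama_client_slot.h"

#include <vector>

// Hypothetical helper: append already-sampled tokens to a slot while the
// per-request or global n_predict budget allows it, then release the slot.
void DrainSlot(LlamaClientSlot& slot, gpt_params& global_params,
               const std::vector<CompletionTokenOutput>& pending) {
  for (const CompletionTokenOutput& token : pending) {
    if (!slot.HasBudget(global_params)) {
      slot.stopped_limit = true;  // ran out of n_predict budget
      break;
    }
    slot.AddTokenString(token);  // becomes a no-op once kRelease is set
    slot.n_decoded += 1;
  }
  slot.Release();       // stamps t_token_generation and requests kRelease
  slot.PrintTimings();  // assumes the timing fields were populated earlier
}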
@@ -0,0 +1,170 @@
#pragma once

#include <cstdint>
#include <string>
#include <vector>

#include "common.h"
#include "json.hpp"
#include "llama.h"
#include "llava.h"
#include "stb_image.h"
#include "trantor/utils/Logger.h"

#include "clip.h"

static bool server_verbose = false;

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif

#if SERVER_VERBOSE != 1
#define LOG_VERBOSE(MSG, ...)
#else
#define LOG_VERBOSE(MSG, ...)                                      \
  do {                                                             \
    if (server_verbose) {                                          \
      server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
    }                                                              \
  } while (0)
#endif

#define LOG_ERROR_LLAMA(MSG, ...) \
  server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_WARNING_LLAMA(MSG, ...) \
  server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO_LLAMA(MSG, ...) \
  server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)

using json = nlohmann::json;

enum class SlotState : uint8_t {
  kIdle,
  kProcessing,
};

enum class SlotCommand : uint8_t {
  kNone,
  kLoadPrompt,
  kRelease,
};

struct SlotParams {
  bool stream = true;
  bool cache_prompt =
      false;  // remember the prompt to avoid reprocessing it in full

  uint32_t seed = -1;      // RNG seed
  int32_t n_keep = 0;      // number of tokens to keep from the initial prompt
  int32_t n_predict = -1;  // new tokens to predict

  std::vector<std::string> antiprompt;

  json input_prefix;
  json input_suffix;
};

struct SlotImage {
  int32_t id;

  bool request_encode_image = false;
  float* image_embedding = nullptr;
  int32_t image_tokens = 0;

  clip_image_u8* img_data;

  std::string prefix_prompt;  // prompt text that comes before this image
};

struct CompletionTokenOutput {
  struct TokenProb {
    llama_token tok;
    float prob;
  };

  std::vector<TokenProb> probs;
  llama_token tok;
  std::string text_to_send;
};

struct LlamaClientSlot {
  int id;
  int task_id = -1;

  struct SlotParams params;

  SlotState state = SlotState::kIdle;
  SlotCommand command = SlotCommand::kNone;

  // used to determine which slot has been unused the longest
  int64_t t_last_used = -1;

  // generation props
  int32_t n_ctx = 0;  // context size per slot
  int32_t n_past = 0;
  int32_t n_decoded = 0;
  int32_t n_remaining = -1;
  int32_t i_batch = -1;

  int32_t num_prompt_tokens = 0;
  int32_t num_prompt_tokens_processed = 0;

  json prompt;

  // when a task is submitted, we first tokenize the prompt and store it here
  std::vector<llama_token> prompt_tokens;

  std::string generated_text;
  llama_token sampled;
  std::vector<llama_token> cache_tokens;
  std::vector<CompletionTokenOutput> generated_token_probs;

  bool infill = false;
  bool embedding = false;
  bool has_next_token = true;
  bool truncated = false;
  bool stopped_eos = false;
  bool stopped_word = false;
  bool stopped_limit = false;

  bool oaicompat = false;
  std::string oaicompat_model;

  std::string stopping_word;

  // sampling
  struct llama_sampling_params sparams;
  llama_sampling_context* ctx_sampling = nullptr;

  // multimodal
  std::vector<SlotImage> images;

  // stats
  size_t sent_count = 0;
  size_t sent_token_probs_index = 0;

  int64_t t_start_process_prompt;
  int64_t t_start_genereration;

  double t_prompt_processing;  // ms
  double t_token_generation;   // ms

  // multitasks
  int multitask_id = -1;

  void Reset();

  bool HasBudget(gpt_params& global_params);

  bool Available() const;

  bool IsProcessing() const;

  void AddTokenString(const CompletionTokenOutput& token);

  void Release();

  json GetFormatedTimings();

  void PrintTimings() const;
};
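A possible engine-side use of the struct above, sketched under the assumption that the surrounding engine keeps a pool of slots; FindFreeSlot and AssignPrompt are hypothetical helpers, not part of this commit:

#include "llama_client_slot.h"

#include <utility>
#include <vector>

// Hypothetical helper: return the first idle slot with no pending command,
// or nullptr if every slot is busy. A real scheduler might instead pick the
// least recently used slot via t_last_used.
LlamaClientSlot* FindFreeSlot(std::vector<LlamaClientSlot>& slots) {
  for (LlamaClientSlot& slot : slots) {
    if (slot.Available()) {
      return &slot;
    }
  }
  return nullptr;
}

// Hypothetical helper: hand a tokenized prompt to a slot and flag it for
// prompt loading; IsProcessing() then reports true until the slot is released.
void AssignPrompt(LlamaClientSlot& slot, std::vector<llama_token> tokens) {
  slot.Reset();
  slot.prompt_tokens = std::move(tokens);
  slot.num_prompt_tokens = static_cast<int32_t>(slot.prompt_tokens.size());
  slot.command = SlotCommand::kLoadPrompt;
}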