Add support for file load progress reporting callbacks #434

Merged (8 commits, Mar 25, 2023)
42 changes: 32 additions & 10 deletions llama.cpp
@@ -267,14 +267,16 @@ static void kv_cache_free(struct llama_kv_cache & cache) {

struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
-/*.n_ctx =*/ 512,
-/*.n_parts =*/ -1,
-/*.seed =*/ 0,
-/*.f16_kv =*/ false,
-/*.logits_all =*/ false,
-/*.vocab_only =*/ false,
-/*.use_mlock =*/ false,
-/*.embedding =*/ false,
+/*.n_ctx =*/ 512,
+/*.n_parts =*/ -1,
+/*.seed =*/ 0,
+/*.f16_kv =*/ false,
+/*.logits_all =*/ false,
+/*.vocab_only =*/ false,
+/*.use_mlock =*/ false,
+/*.embedding =*/ false,
+/*.progress_callback =*/ nullptr,
+/*.progress_callback_user_data =*/ nullptr,
};

return result;
@@ -290,7 +292,9 @@ static bool llama_model_load(
int n_ctx,
int n_parts,
ggml_type memory_type,
-bool vocab_only) {
+bool vocab_only,
+llama_progress_callback progress_callback,
+void *progress_callback_user_data) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

const int64_t t_start_us = ggml_time_us();
@@ -576,6 +580,10 @@ static bool llama_model_load(

std::vector<uint8_t> tmp;

+if (progress_callback) {
+progress_callback(0.0, progress_callback_user_data);
+}

for (int i = 0; i < n_parts; ++i) {
const int part_id = i;
//const int part_id = n_parts - i - 1;
@@ -589,6 +597,10 @@

fin = std::ifstream(fname_part, std::ios::binary);
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());

+fin.seekg(0, fin.end);
+const size_t file_size = fin.tellg();

fin.seekg(file_offset);

// load weights
@@ -764,6 +776,11 @@ static bool llama_model_load(
model.n_loaded++;

// progress
+if (progress_callback) {
+double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
+double current_progress = (double(i) + current_file_progress) / double(n_parts);
+progress_callback(current_progress, progress_callback_user_data);
+}
if (model.n_loaded % 8 == 0) {
fprintf(stderr, ".");
fflush(stderr);
@@ -786,6 +803,10 @@ static bool llama_model_load(

lctx.t_load_us = ggml_time_us() - t_start_us;

+if (progress_callback) {
+progress_callback(1.0, progress_callback_user_data);
+}

return true;
}

@@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
-params.vocab_only)) {
+params.vocab_only, params.progress_callback,
+params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx);
return nullptr;
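For reference, the reported value combines the position within the current part file with the part index, so progress rises monotonically across a multi-part model. A minimal standalone sketch of that arithmetic (the helper name and the sample numbers are illustrative, not part of this PR):

```cpp
// Not part of the PR: a standalone sketch of the progress arithmetic above.
// Per-part progress is the fraction of bytes read past the header offset;
// overall progress interpolates that across the n_parts model files.
#include <cstddef>
#include <cstdio>

static double overall_progress(int part_index, int n_parts,
                               size_t file_offset, size_t file_pos, size_t file_size) {
    const double current_file_progress =
        double(file_pos - file_offset) / double(file_size - file_offset);
    return (double(part_index) + current_file_progress) / double(n_parts);
}

int main() {
    // illustrative numbers: halfway through the tensors of part 2 of a 2-part model
    printf("%.2f\n", overall_progress(1, 2, 1024, 1524, 2024)); // prints 0.75
    return 0;
}
```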
7 changes: 7 additions & 0 deletions llama.h
@@ -45,6 +45,8 @@ extern "C" {

} llama_token_data;

+typedef void (*llama_progress_callback)(double progress, void *ctx);

struct llama_context_params {
int n_ctx; // text context
int n_parts; // -1 for default
@@ -55,6 +57,11 @@ extern "C" {
bool vocab_only; // only load the vocabulary, no weights
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only

+// called with a progress value between 0 and 1, pass NULL to disable
+llama_progress_callback progress_callback;
+// context pointer passed to the progress callback
+void * progress_callback_user_data;
};

LLAMA_API struct llama_context_params llama_context_default_params();
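A sketch of how a caller can wire up the new fields: the printing callback and the model path below are illustrative, while llama_context_default_params, llama_init_from_file, and llama_free are the existing llama.h entry points.

```cpp
// Illustrative caller-side usage of the new progress callback fields.
#include <cstdio>
#include "llama.h"

static void print_progress(double progress, void * ctx) {
    (void) ctx; // user data is not needed in this example
    fprintf(stderr, "\rloading: %3.0f%%", progress * 100.0);
}

int main() {
    struct llama_context_params params = llama_context_default_params();
    params.progress_callback           = print_progress;
    params.progress_callback_user_data = nullptr; // leave NULL if the callback needs no state

    // path is illustrative; any ggml model file works
    struct llama_context * ctx = llama_init_from_file("models/7B/ggml-model-f16.bin", params);
    if (ctx == nullptr) {
        fprintf(stderr, "\nfailed to load model\n");
        return 1;
    }
    fprintf(stderr, "\n");

    llama_free(ctx);
    return 0;
}
```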