llama : fix session load / save #1263

Merged: 1 commit, May 1, 2023
20 changes: 10 additions & 10 deletions examples/main/main.cpp
@@ -161,23 +161,22 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> session_tokens;
 
     if (!path_session.empty()) {
-        fprintf(stderr, "%s: attempting to load saved session from %s..\n", __func__, path_session.c_str());
+        fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
 
-        // REVIEW - fopen to check for existing session
+        // fopen to check for existing session
         FILE * fp = std::fopen(path_session.c_str(), "rb");
         if (fp != NULL) {
             std::fclose(fp);
 
             session_tokens.resize(params.n_ctx);
             size_t n_token_count_out = 0;
-            const size_t n_session_bytes = llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
+            if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
+                fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
+                return 1;
+            }
             session_tokens.resize(n_token_count_out);
 
-            if (n_session_bytes > 0) {
-                fprintf(stderr, "%s: loaded %zu bytes of session data!\n", __func__, n_session_bytes);
-            } else {
-                fprintf(stderr, "%s: could not load session file, will recreate\n", __func__);
-            }
+            fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
         } else {
             fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
         }
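The new call site uses a caller-provided-buffer convention: grow the vector to a full context up front, pass its capacity as the limit, and shrink it to the count actually read once the load succeeds. A minimal self-contained sketch of that pattern (assuming `ctx` is an initialized llama_context and `path` names an existing session file; the helper is not part of the PR):

#include <vector>
#include "llama.h"

// Sketch only: mirrors the call pattern above, lifted out of main().
static bool try_load_session(llama_context * ctx, const char * path, int n_ctx,
                             std::vector<llama_token> & session_tokens) {
    session_tokens.resize(n_ctx);             // room for up to a full context
    size_t n_token_count_out = 0;
    if (!llama_load_session_file(ctx, path, session_tokens.data(),
                                 session_tokens.capacity(), &n_token_count_out)) {
        return false;                         // caller decides whether this is fatal
    }
    session_tokens.resize(n_token_count_out); // keep only the tokens actually loaded
    return true;
}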
@@ -214,7 +213,7 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
         params.n_keep = (int)embd_inp.size();
     }

@@ -329,7 +328,7 @@ int main(int argc, char ** argv) {
                 // insert n_left/2 tokens at the start of embd from last_n_tokens
                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
 
-                // REVIEW - stop saving session if we run out of context
+                // stop saving session if we run out of context
                 path_session = "";
 
                 //printf("\n---\n");
@@ -355,6 +354,7 @@ int main(int argc, char ** argv) {
                     n_session_consumed++;
 
                     if (n_session_consumed >= (int) session_tokens.size()) {
+                        ++i;
                         break;
                     }
                 }

Member Author (commenting on the `++i` line): This was the "off-by-one" bug causing different outputs, observed in #1257
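The comment refers to the session-prefix matching loop this hunk sits in: tokens reused from the session are counted in `i` and later erased from the front of `embd`, but when the loop exits through this `break` the for-loop's own `i++` never runs, so without the added `++i` the last matched token stayed in `embd` and was evaluated a second time, diverging the output. A self-contained sketch of that loop, simplified from main.cpp (the surrounding state is passed as parameters here for illustration):

#include <vector>
#include "llama.h"

// Returns how many leading tokens of embd are already covered by the loaded
// session; the caller erases embd[0, i) instead of re-evaluating them.
static size_t match_session_prefix(const std::vector<llama_token> & embd,
                                   const std::vector<llama_token> & session_tokens,
                                   int & n_past, int & n_session_consumed) {
    size_t i = 0;
    for ( ; i < embd.size(); i++) {
        if (embd[i] != session_tokens[n_session_consumed]) {
            break; // session diverges from the new prompt here
        }
        n_past++;
        n_session_consumed++;
        if (n_session_consumed >= (int) session_tokens.size()) {
            ++i;   // the for-loop's own i++ is skipped by the break below;
            break; // without this, embd[i] (already matched) is re-evaluated
        }
    }
    return i;
}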
133 changes: 79 additions & 54 deletions llama.cpp
@@ -2567,6 +2567,85 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     return nread;
 }

+bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(path_session, "rb");
+
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (!(magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION)) {
+            fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
+
+        llama_hparams session_hparams;
+        file.read_raw(&session_hparams, sizeof(llama_hparams));
+
+        if (session_hparams != ctx->model.hparams) {
+            fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
+            return false;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_exp = llama_get_state_size(ctx);
+
+        if (n_state_size_cur != n_state_size_exp) {
+            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+            return false;
+        }
+
+        std::vector<uint8_t> state_data(n_state_size_cur);
+        file.read_raw(state_data.data(), n_state_size_cur);
+
+        llama_set_state_data(ctx, state_data.data());
+    }
+
+    return true;
+}
+
+bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(path_session, "wb");
+
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
+
+    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state
+    {
+        const size_t n_state_size = llama_get_state_size(ctx);
+
+        std::vector<uint8_t> state_data(n_state_size);
+        llama_copy_state_data(ctx, state_data.data());
+
+        file.write_raw(state_data.data(), n_state_size);
+    }
+
+    return true;
+}

 int llama_eval(
         struct llama_context * ctx,
            const llama_token * tokens,
@@ -2694,57 +2773,3 @@ const char * llama_print_system_info(void) {
 std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
     return ctx->model.tensors_by_name;
 }

-size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    // TODO leverage mmap
-    llama_file file(path_session, "rb");
-    const uint32_t magic = file.read_u32();
-    const uint32_t version = file.read_u32();
-
-    if (!(magic == 'ggsn' && version == 0)) {
-        fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-        return 0;
-    }
-
-    llama_hparams session_hparams;
-    file.read_raw(&session_hparams, sizeof(llama_hparams));
-
-    // REVIEW
-    if (session_hparams != ctx->model.hparams) {
-        fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
-        return 0;
-    }
-
-    const uint32_t n_token_count = file.read_u32();
-    LLAMA_ASSERT(n_token_capacity >= n_token_count);
-    file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-    *n_token_count_out = n_token_count;
-
-    const size_t n_state_size = file.size - file.tell();
-    const size_t n_orig_state_size = llama_get_state_size(ctx);
-    if (n_state_size != n_orig_state_size) {
-        fprintf(stderr, "%s : failed to validate state size\n", __func__);
-    }
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    file.read_raw(state_data.get(), n_state_size);
-    return llama_set_state_data(ctx, state_data.get());
-}
-
-size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    // TODO save temp & swap
-    llama_file file(path_session, "wb");
-
-    const size_t n_state_size = llama_get_state_size(ctx);
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    llama_copy_state_data(ctx, state_data.get());
-
-    file.write_u32('ggsn'); // magic
-    file.write_u32(0); // version
-    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
-
-    file.write_u32((uint32_t) n_token_count); // REVIEW
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    file.write_raw(state_data.get(), n_state_size);
-    return n_state_size; // REVIEW
-}
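Read and write sides together pin down the session file format. A sketch of the on-disk layout as implied by the code above, with an illustrative save call (the helper and the "session.bin" name are made up; main.cpp passes its user-supplied session path):

// Layout implied by llama_save_session_file / llama_load_session_file:
//
//   uint32          magic           LLAMA_SESSION_MAGIC ('ggsn')
//   uint32          version         LLAMA_SESSION_VERSION (0)
//   llama_hparams   hparams         raw bytes; must match the loaded model
//   uint32          n_token_count
//   llama_token[n_token_count]      the prompt tokens
//   uint8[llama_get_state_size(ctx)]  the full context state
#include <vector>
#include "llama.h"

static bool save_session(llama_context * ctx, const std::vector<llama_token> & prompt) {
    // "session.bin" is illustrative only
    return llama_save_session_file(ctx, "session.bin", prompt.data(), prompt.size());
}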
12 changes: 7 additions & 5 deletions llama.h
@@ -19,9 +19,11 @@
 # define LLAMA_API
 #endif
 
-#define LLAMA_FILE_VERSION 1
-#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
-#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
+#define LLAMA_FILE_VERSION 1
+#define LLAMA_FILE_MAGIC 'ggjt'
+#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
+#define LLAMA_SESSION_MAGIC 'ggsn'
+#define LLAMA_SESSION_VERSION 0

 #ifdef __cplusplus
 extern "C" {
@@ -138,8 +140,8 @@ extern "C" {
     LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
 
     // Save/load session file
-    LLAMA_API size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
-    LLAMA_API size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
+    LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
+    LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);

     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
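One header detail worth noting: the magics are now spelled as multi-character literals ('ggjt', 'ggsn') rather than hex. On the compilers this codebase targets, those literals pack the four characters big-endian, so 'ggjt' still equals the old 0x67676a74, but multi-character constants are formally implementation-defined in C and C++. A compile-time check along these lines (not part of the PR) would make the assumption explicit:

// Hypothetical sanity check: confirm multi-char literal packing on this compiler.
static_assert('ggjt' == 0x67676a74, "'ggjt' must equal the old hex LLAMA_FILE_MAGIC");
static_assert('ggml' == 0x67676d6c, "'ggml' must equal the old unversioned magic");
static_assert('ggsn' == 0x6767736e, "unexpected packing for the session magic");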