Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama, main : save state incrementally #1310

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 81 additions & 70 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,12 @@ int main(int argc, char ** argv) {
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');

std::string path_session = params.path_session;
std::vector<llama_token> session_tokens;
// tokenize the prompt
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);

// restore prompt from saved session
const std::string path_session = params.path_session;
int n_matching_session_tokens = 0;
if (!path_session.empty()) {
fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());

Expand All @@ -151,49 +154,43 @@ int main(int argc, char ** argv) {
if (fp != NULL) {
std::fclose(fp);

session_tokens.resize(params.n_ctx);
std::vector<llama_token> session_tokens(embd_inp.size());
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
session_tokens.resize(n_token_count_out);

fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
// find matching input prefix from saved session
for (llama_token id : session_tokens) {
if (n_matching_session_tokens >= (int) embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
break;
}
n_matching_session_tokens++;
}

if (n_matching_session_tokens >= (int) embd_inp.size()) {
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (int) (embd_inp.size() / 2)) {
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%d / %zu tokens); will mostly be reevaluated\n",
__func__, n_matching_session_tokens, embd_inp.size());
} else {
fprintf(stderr, "%s: session file matches %d / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
Comment on lines +173 to +181
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer a one-line solution that displays the count of matching tokens. These 9 lines are more verbose than necessary in an already bloated file.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's fair

} else {
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
}
}

// tokenize the prompt
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);

const int n_ctx = llama_n_ctx(ctx);

if ((int) embd_inp.size() > n_ctx - 4) {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
}

// debug message about similarity of saved session, if applicable
size_t n_matching_session_tokens = 0;
if (session_tokens.size()) {
for (llama_token id : session_tokens) {
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= embd_inp.size()) {
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
__func__, n_matching_session_tokens, embd_inp.size());
} else {
fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
}

// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
Expand Down Expand Up @@ -283,16 +280,11 @@ int main(int argc, char ** argv) {
bool is_antiprompt = false;
bool input_echo = true;

// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
// initial prompt so it doesn't need to be an exact match.
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);


int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
int n_session_write_past = 0;

// the first thing we will do is to output the prompt, so set color accordingly
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
Expand All @@ -306,17 +298,15 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() > n_ctx) {
bool needs_swap = n_past + (int) embd.size() > n_ctx;
if (needs_swap) {
const int n_left = n_past - params.n_keep;

n_past = params.n_keep;

// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());

// stop saving session if we run out of context
path_session = "";

//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
Expand All @@ -326,27 +316,12 @@ int main(int argc, char ** argv) {
//printf("\n---\n");
}

// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;
}

n_past++;
n_session_consumed++;

if (n_session_consumed >= (int) session_tokens.size()) {
++i;
break;
}
}
if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i);
}
// skip evaluation of tokens in the input prefix that matched session
if (n_session_consumed < n_matching_session_tokens) {
int n_skip = std::min((int) embd.size(), n_matching_session_tokens - n_session_consumed);
embd.erase(embd.begin(), embd.begin() + n_skip);
n_session_consumed += n_skip;
n_past += n_skip;
}

// evaluate tokens in batches
Expand All @@ -363,14 +338,42 @@ int main(int argc, char ** argv) {
n_past += n_eval;
}

if (embd.size() > 0 && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
// save session after context swap
if (!path_session.empty() && needs_swap) {
int n_tokens = n_past - params.n_keep;
if (!llama_append_session_file(
ctx, path_session.c_str(), params.n_keep,
last_n_tokens.data() + last_n_tokens.size() - n_tokens, n_tokens)) {
fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
__func__, path_session.c_str());
return 1;
}

n_session_write_past = n_past;
Comment on lines +341 to +352
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the whole prompt would have to be >= the size of n_ctx for the session to be saved, but reusing that whole prompt next time will give us the error "prompt is too long".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, per my other comment I'll probably drop this. To be clear though, if you're talking about what's saved in the session, the state changes are incremental. The original prompt state will always be there in the first segment of the session regardless of what comes after and that's all the startup code loads.

}
}

embd.clear();

// save prompt evaluation state to session file
if (!path_session.empty() && !n_session_write_past && (int) embd_inp.size() <= n_consumed) {
if (!llama_init_session_file(ctx, path_session.c_str())) {
fprintf(stderr, "%s: error: unable to start session file '%s'\n",
__func__, path_session.c_str());
return 1;
}

if (!llama_append_session_file(
ctx, path_session.c_str(), 0,
last_n_tokens.data() + last_n_tokens.size() - n_past, n_past)) {
fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
__func__, path_session.c_str());
return 1;
}

n_session_write_past = n_past;
}

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
const float temp = params.temp;
Expand All @@ -387,12 +390,6 @@ int main(int argc, char ** argv) {
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;

// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session) {
need_to_save_session = false;
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}

llama_token id = 0;

{
Expand Down Expand Up @@ -608,6 +605,20 @@ int main(int argc, char ** argv) {
}
}

if (!path_session.empty()) {
int n_session_remain = n_past - n_session_write_past;
fprintf(stderr, "\n%s: saving remaining state (%d tokens) to session file '%s'",
__func__, n_session_remain, path_session.c_str());
if (!llama_append_session_file(
ctx, path_session.c_str(), n_session_write_past,
last_n_tokens.data() + last_n_tokens.size() - embd.size() - n_session_remain,
n_session_remain)) {
fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
__func__, path_session.c_str());
return 1;
}
}

llama_print_timings(ctx);
llama_free(ctx);

Expand Down
2 changes: 1 addition & 1 deletion examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ int main(int argc, char ** argv) {
// Save state (rng, logits, embedding and kv_cache) to file
{
FILE *fp_write = fopen("dump_state.bin", "wb");
llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
llama_copy_state_data(ctx, state_mem, 0); // could also copy directly to memory mapped file
fwrite(state_mem, 1, state_size, fp_write);
fclose(fp_write);
}
Expand Down
Loading