
llama, main : save state incrementally #1310

Closed

Conversation

@ejones (Collaborator) commented May 4, 2023

Updates get/set state and save/load session to support incremental changes to KV state. This lets us checkpoint the state after the initial prompt eval and at the end of the session, saving only what changed. On reload, we only load what is needed, flexibly supporting both the "prompt cache" and "prompt continuation" use cases. Eventually, I envision this being used to, e.g., sync state more frequently or store branching evaluation states as deltas.

Changes

  • updates llama_copy_state_data with an additional parameter specifying the token offset. Passing 0 gives the current behavior; otherwise, only the ntok - offset most recent tokens are serialized.
  • provides an incremental session saving API: llama_init_session_file and llama_append_session_file, the latter being called N times on successive batches of tokens and incremental state (a usage sketch follows this list). The original llama_save_session_file is simply init + append once.
  • the session file layout changes and the version is bumped; more on that below
  • llama_load_session_file now stops reading after filling the provided token capacity (so loading the prompt from a larger session doesn't incur the full cost)
  • main takes advantage of these changes by saving state after prompt eval, on context swap, and on exit
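
For illustration, a rough sketch of how a caller might use the incremental saving API. The exact signatures of llama_init_session_file and llama_append_session_file are not spelled out above; the shapes below are inferred from the main.cpp hunk quoted later in this thread and should be treated as assumptions, not the literal header.

#include "llama.h"
#include <vector>

// Hypothetical usage sketch; signatures are assumed, not taken from a header:
//   bool llama_init_session_file  (llama_context * ctx, const char * path);
//   bool llama_append_session_file(llama_context * ctx, const char * path,
//                                  int token_offset,
//                                  const llama_token * tokens, size_t n_tokens);
bool save_session_incrementally(llama_context * ctx, const char * path,
                                const std::vector<llama_token> & prompt,
                                const std::vector<llama_token> & generated) {
    // 1) write the header once
    if (!llama_init_session_file(ctx, path)) {
        return false;
    }
    // 2) append the prompt segment (offset 0 = from the start of the KV state)
    if (!llama_append_session_file(ctx, path, /*token_offset=*/0,
                                   prompt.data(), prompt.size())) {
        return false;
    }
    // ... generation happens in between ...
    // 3) later, append only the tokens/state produced after the prompt (the delta)
    return llama_append_session_file(ctx, path, /*token_offset=*/(int) prompt.size(),
                                     generated.data(), generated.size());
}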

Approach

The new session file format is the header followed by one or more segments of tokens + incremental state:

<... magic, version, hparams > | ( <u32> token_count | <token_count> tokens | <size_t> state_size | state )*

Each segment represents the incremental evaluation of successive sequences of tokens and the resulting change in state. The incremental state stores the final RNG and logits for that batch, the starting token offset, and the portion of KV state from that offset onward.
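
As a rough illustration of what one segment carries (field names are for exposition only; the actual code writes these values out sequentially rather than as a struct):

// Illustrative layout of one segment (written field by field, not as a struct):
//
//   uint32_t    token_count              number of tokens in this batch
//   llama_token tokens[token_count]      the evaluated tokens
//   size_t      state_size               size of the serialized state blob
//   uint8_t     state[state_size]        RNG + logits for this batch,
//                                        starting token offset, and the
//                                        KV cache delta from that offset onward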

When loading saved state in llama_load_session_file, the caller provides the desired number of tokens (e.g., prompt length) to restore. As those tokens are read from successive segments, the state from each segment is applied. Specifically, this means restoring the RNG and logits and applying the KV state delta. Finally, if the requested tokens don't fit cleanly into a state segment, one less token is returned to force an evaluation and ensure correct logits at that position.
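
From the caller's side this could look roughly like the fragment below; the llama_load_session_file signature follows the existing llama.cpp API, and ctx, embd_inp, and path_session stand in for the variables in examples/main/main.cpp. This is a sketch, not the literal patch.

// restore at most embd_inp.size() tokens' worth of state; reading stops once
// that capacity is filled, so a larger session doesn't incur the full cost
std::vector<llama_token> session_tokens(embd_inp.size());
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(),
                             session_tokens.data(), session_tokens.size(),
                             &n_token_count_out)) {
    fprintf(stderr, "failed to load session file '%s'\n", path_session.c_str());
}
session_tokens.resize(n_token_count_out);
// n_token_count_out may come back one short of the request so that the last
// prompt token is reevaluated and the logits at that position are correct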

I also took this opportunity to clean up the initial session code in main.

Testing

  1. Prompt startup time with session using chat-13B on 30B: 35s cold vs 2s warm
% ./examples/chat-13B.sh -m ~/llama-models/30B/ggml-model-q4_0.bin --session sessions/chat-1tok.30.bin -n 1
llama_print_timings:       total time = 34634.49 ms
% ./examples/chat-13B.sh -m ~/llama-models/30B/ggml-model-q4_0.bin --session sessions/chat-1tok.30.bin -n 1
llama_print_timings:       total time =  2308.30 ms
  2. Repeat prompt invocations on a session yield the same results:
% ./main -m ~/llama-models/30B/ggml-model-q4_0.bin --seed 1 --session sessions/meaning-life-30.bin -n 5 -p 'The meaning of life is 4'
...
 The meaning of life is 42: Douglas Adams
...
llama_print_timings:       total time =  2188.90 ms

% ./main -m ~/llama-models/30B/ggml-model-q4_0.bin --seed 1 --session sessions/meaning-life-30.bin -n 5 -p 'The meaning of life is 4'
...
 The meaning of life is 42: Douglas Adams
...
llama_print_timings:       total time =   1320.01 ms
  3. Growing the prompt on successive invocations with a session - run time stays roughly constant:
 % ./main -m ~/llama-models/30B/ggml-model-q4_0.bin --seed 1 --session sessions/meaning-life-30.bin -n 15 -p 'The meaning of life is 4'
...
 The meaning of life is 42: Douglas Adams
The Hitchhiker's Guide to the
...
llama_print_timings:       total time =  3978.35 ms

% ./main -m ~/llama-models/30B/ggml-model-q4_0.bin --seed 1 --session sessions/meaning-life-30.bin -n 15 -p 'The meaning of life is 42: Douglas Adams
The Hitchhiker'\''s Guide to the'
...
 The meaning of life is 42: Douglas Adams
The Hitchhiker's Guide to the Galaxy, as it appears in the computer game adaptation of the series.
...
llama_print_timings:       total time =  3331.50 ms

% ./main -m ~/llama-models/30B/ggml-model-q4_0.bin --seed 1 --session sessions/meaning-life-30.bin -n 15 -p 'The meaning of life is 42: Douglas Adams
The Hitchhiker'\''s Guide to the Galaxy, as it appears in the computer game adaptation of the series.'
...
 The meaning of life is 42: Douglas Adams
The Hitchhiker's Guide to the Galaxy, as it appears in the computer game adaptation of the series.
It’s been a few years since I last read Douglas Adams
...
llama_print_timings:       total time =  3379.11 ms
  4. Rerunning a prefix of the saved prompt gives sensible results:
 % ./main -m ~/llama-models/30B/ggml-model-q4_0.bin --seed 1 --session sessions/meaning-life-30.bin -n 15 -p 'The meaning of life is 4'
...
 The meaning of life is 42: Douglas Adams...

 % ./main -m ~/llama-models/30B/ggml-model-q4_0.bin --seed 1 --session sessions/meaning-life-30.bin -n 15 -p 'The meaning of life ' 
...
 The meaning of life  and whether it has any purpose at all 
...
  5. Session size is still reasonable (but will now scale with the total number of tokens):
 % du -hs sessions/*
768M	sessions/chat-1tok.30.bin
 31M	sessions/meaning-life-30.bin
  6. examples/save-load-state and non-session usage work

@DannyDaemonic (Contributor)

If the primary goal of a session file is to speed up the initial processing, we may want to simply disable updating the session file once we have to reset the context. When the context fills up, it isn't being compressed or anything particularly clever; we're simply removing half of the context and reevaluating the second half starting at the beginning of the context window so there's room to continue. If you save after the context has been reset, the next time we start with the same prompt, we won't be able to use any of that context (since the first half was dropped).
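
For context, the context swap being described looks roughly like this in examples/main/main.cpp (condensed; variable names follow main.cpp but this is not the exact code):

// when the context window is full, keep the first n_keep tokens, drop the
// older half of the rest, and reevaluate the newer half from position n_keep
if (n_past + (int) embd.size() > n_ctx) {
    const int n_left = n_past - params.n_keep;
    n_past = params.n_keep;
    // re-queue roughly the last n_left/2 tokens for evaluation
    embd.insert(embd.begin(),
                last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(),
                last_n_tokens.end() - embd.size());
}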

Comment on lines +173 to +181
if (n_matching_session_tokens >= (int) embd_inp.size()) {
    fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (int) (embd_inp.size() / 2)) {
    fprintf(stderr, "%s: warning: session file has low similarity to prompt (%d / %zu tokens); will mostly be reevaluated\n",
        __func__, n_matching_session_tokens, embd_inp.size());
} else {
    fprintf(stderr, "%s: session file matches %d / %zu tokens of prompt\n",
        __func__, n_matching_session_tokens, embd_inp.size());
}
Collaborator

I'd prefer a one-line solution that displays the count of matching tokens. This is more verbose than it needs to be; it adds 9 lines to an already bloated file.
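
For illustration, a single unconditional message could cover all three cases, e.g. (this exact wording was not proposed in the PR):

fprintf(stderr, "%s: session file matches %d / %zu tokens of prompt\n",
        __func__, n_matching_session_tokens, embd_inp.size());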

Collaborator Author

Yeah that's fair

Comment on lines +341 to +352
// save session after context swap
if (!path_session.empty() && needs_swap) {
    int n_tokens = n_past - params.n_keep;
    if (!llama_append_session_file(
            ctx, path_session.c_str(), params.n_keep,
            last_n_tokens.data() + last_n_tokens.size() - n_tokens, n_tokens)) {
        fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
            __func__, path_session.c_str());
        return 1;
    }

    n_session_write_past = n_past;
Collaborator

Also, the whole prompt would have to be >= n_ctx in size for this save to happen, but using that whole prompt next time will give us the "prompt is too long" error.

Collaborator Author

Yeah, per my other comment, I'll probably drop this. To be clear though, if you're talking about what's saved in the session, the state changes are incremental. The original prompt state will always be there in the first segment of the session regardless of what comes after, and that's all the startup code loads.

@ejones (Collaborator Author) commented May 4, 2023

@DannyDaemonic yes, saving the swap does nothing for the prompt caching use case; I did that with an eye to things like restoring or replaying full sessions, and just generally for completeness. To clarify though, because this approach is incremental and state changes are appended, the original prompt state will be in there and usable regardless. When we load the session, we only read far enough (i.e., replay state changes) to match the original prompt.

That said, on second thought, I think session restoration with a swap might require more thought. So maybe I'll omit it for now.

@ivanstepanovftw (Collaborator) commented May 4, 2023

Also, keep in mind that its first use is prompt caching: if someone changes the sampling parameters, the next prediction will be different.

@ejones (Collaborator Author) commented May 6, 2023

Closing this in favor of #1338, which just focuses on allowing (now optionally) the full output to be saved to the session.
