llama, main : save state incrementally #1310
Conversation
If the primary goal of a
if (n_matching_session_tokens >= (int) embd_inp.size()) {
    fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (int) (embd_inp.size() / 2)) {
    fprintf(stderr, "%s: warning: session file has low similarity to prompt (%d / %zu tokens); will mostly be reevaluated\n",
        __func__, n_matching_session_tokens, embd_inp.size());
} else {
    fprintf(stderr, "%s: session file matches %d / %zu tokens of prompt\n",
        __func__, n_matching_session_tokens, embd_inp.size());
}
I'd prefer a one-line solution that displays the count of matching tokens. These 9 lines are more verbose than needed in an already bloated file.
Yeah that's fair
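For reference, a condensed version along those lines might look something like this (a sketch only, reusing the variables from the snippet above):

// one line summarizing how much of the prompt the session file covers
fprintf(stderr, "%s: session file matches %d / %zu prompt tokens\n",
        __func__, n_matching_session_tokens, embd_inp.size());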
// save session after context swap
if (!path_session.empty() && needs_swap) {
    int n_tokens = n_past - params.n_keep;
    if (!llama_append_session_file(
            ctx, path_session.c_str(), params.n_keep,
            last_n_tokens.data() + last_n_tokens.size() - n_tokens, n_tokens)) {
        fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
            __func__, path_session.c_str());
        return 1;
    }

    n_session_write_past = n_past;
}
And the whole prompt would have to be >= n_ctx in size to save the session here, but using that whole prompt next time will give us the "prompt is too long" error.
Yeah, per my other comment I'll probably drop this. To be clear though, if you're talking about what's saved in the session, the state changes are incremental. The original prompt state will always be there in the first segment of the session regardless of what comes after and that's all the startup code loads.
@DannyDaemonic yes, saving the swap does nothing for the prompt caching use case; I did that with an eye to things like restoring or replaying full sessions, and just generally for completeness. To clarify though, because this approach is incremental and state changes are appended, the original prompt state will be in there and usable regardless. When we load the session, we only read far enough (i.e., replay state changes) to match the original prompt. That said, on second thought, I think session restoration with a swap might require more thought. So maybe I'll omit it for now.
Also, keep in mind that its first use is prompt caching - if someone changes the sampling parameters, the next prediction is different.
Closing this in favor of #1338, which just focuses on (now optionally) allowing the full output to be saved to the session.
Updates get/set state and save/load session to support incremental changes to kv state. This lets us checkpoint the state after initial prompt eval and at the end of the session, saving only what changed. On reload, we only load what is needed, flexibly supporting both the "prompt cache" and "prompt continuation" use cases. Eventually, I envision this being used to e.g., sync state more frequently, or store branching evaluation states as deltas.
Changes

- Extends llama_copy_state_data with an additional parameter specifying the token offset. Passing 0 gives the current behavior. Otherwise, only the ntok - offset most recent tokens are serialized.
- Adds llama_init_session_file and llama_append_session_file, the latter being called N times on successive batches of tokens and incremental state. The original llama_save_session_file is simply init + append once (see the sketch after this list).
- llama_load_session_file now stops reading after filling the provided token capacity (so loading the prompt from a larger session doesn't incur full cost).
- main takes advantage of these changes by saving state after prompt eval, on context swap, and on exit.
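A rough sketch of the save-side call pattern (the exact llama_init_session_file signature isn't spelled out here, so the one below is an assumption; the append call mirrors the context-swap snippet earlier in the thread, and the variables are the ones used in main):

// after the initial prompt eval: create the file with its header and a first
// segment covering the prompt (llama_init_session_file signature assumed)
llama_init_session_file(ctx, path_session.c_str(), embd_inp.data(), embd_inp.size());
n_session_write_past = n_past;

// ... generation advances n_past ...

// at a later checkpoint (context swap, exit): append only what changed since
// the last write, same shape as the context-swap call above
const int n_new = n_past - n_session_write_past;
llama_append_session_file(ctx, path_session.c_str(), n_session_write_past,
                          last_n_tokens.data() + last_n_tokens.size() - n_new, n_new);
n_session_write_past = n_past;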
Approach

The new session file format is the header followed by one or more segments of tokens + incremental state:
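Roughly, and with the field grouping below being illustrative rather than lifted from the patch:

    header (magic, version, hparams)
    segment 1: tokens + { rng, logits, token offset, KV state from that offset }
    segment 2: tokens + { rng, logits, token offset, KV state from that offset }
    ...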
Each segment represents the incremental evaluation of successive sequences of tokens and the resulting change in state. The incremental state stores the final RNG and logits for that batch, the starting token offset, and the portion of KV state from that offset onward.
When loading saved state in llama_load_session_file, the caller provides the desired number of tokens (e.g., prompt length) to restore. As those tokens are read from successive segments, each segment's state is applied. Specifically, this means restoring the RNG and logits and applying the KV state delta. Finally, if the requested tokens don't fit cleanly into a state segment, one less token is returned to force an evaluation and ensure correct logits at that position.

I also took this opportunity to clean up the initial session code in main.
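On the caller side, the prompt-cache path in main then looks roughly like this (a sketch reusing variable names from the snippets above; llama_load_session_file keeps its existing shape, with the token capacity set to the prompt length):

// size the buffer to the prompt only, so a longer session stops loading early
std::vector<llama_token> session_tokens(embd_inp.size());

size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(),
                             session_tokens.data(), session_tokens.size(),
                             &n_token_count_out)) {
    fprintf(stderr, "%s: error: failed to load session file '%s'\n",
            __func__, path_session.c_str());
    return 1;
}
// if the request didn't land exactly on a segment boundary, one fewer token is
// returned, forcing a single eval to recompute the logits at that position
session_tokens.resize(n_token_count_out);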
Testing

- chat-13B on 30B: 35s cold vs 2s warm
- examples/save-load-state and non-session usage work