llama, main : save state incrementally #1310
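This PR replaces the all-at-once llama_save_session_file call in the main example (examples/main/main.cpp) with two new llama API functions, llama_init_session_file and llama_append_session_file, so state is written incrementally: once after the prompt is evaluated, again after each context swap, and once more at exit. For orientation, here is a sketch of the calling pattern; the helper name is made up, and reading the third argument of llama_append_session_file as "token offset where the appended segment starts" is inferred from the call sites in the diff, not a documented contract.

    #include <cstdio>
    #include <string>
    #include <vector>

    #include "llama.h"

    // Hypothetical helper, not part of this PR: start a fresh session file and
    // write the evaluated prompt as its first segment (token offset 0).
    // llama_init_session_file creates/truncates the file;
    // llama_append_session_file appends the current state for `tokens`,
    // starting at the given token offset.
    static bool save_prompt_segment(
            struct llama_context * ctx,
            const std::string & path_session,
            const std::vector<llama_token> & tokens) {
        if (!llama_init_session_file(ctx, path_session.c_str())) {
            fprintf(stderr, "unable to start session file '%s'\n", path_session.c_str());
            return false;
        }
        if (!llama_append_session_file(ctx, path_session.c_str(), 0,
                                       tokens.data(), (int) tokens.size())) {
            fprintf(stderr, "unable to write to session file '%s'\n", path_session.c_str());
            return false;
        }
        return true;
    }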
@@ -140,9 +140,12 @@ int main(int argc, char ** argv) {
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
 
-    std::string path_session = params.path_session;
-    std::vector<llama_token> session_tokens;
+    // tokenize the prompt
+    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+
+    // restore prompt from saved session
+    const std::string path_session = params.path_session;
+    int n_matching_session_tokens = 0;
     if (!path_session.empty()) {
         fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
@@ -151,49 +154,43 @@ int main(int argc, char ** argv) {
         if (fp != NULL) {
             std::fclose(fp);
 
-            session_tokens.resize(params.n_ctx);
+            std::vector<llama_token> session_tokens(embd_inp.size());
             size_t n_token_count_out = 0;
             if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
                 fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
 
             fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
+
+            // find matching input prefix from saved session
+            for (llama_token id : session_tokens) {
+                if (n_matching_session_tokens >= (int) embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
+                    break;
+                }
+                n_matching_session_tokens++;
+            }
+
+            if (n_matching_session_tokens >= (int) embd_inp.size()) {
+                fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
+            } else if (n_matching_session_tokens < (int) (embd_inp.size() / 2)) {
+                fprintf(stderr, "%s: warning: session file has low similarity to prompt (%d / %zu tokens); will mostly be reevaluated\n",
+                    __func__, n_matching_session_tokens, embd_inp.size());
+            } else {
+                fprintf(stderr, "%s: session file matches %d / %zu tokens of prompt\n",
+                    __func__, n_matching_session_tokens, embd_inp.size());
+            }
         } else {
             fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
         }
     }
 
-    // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
-
     const int n_ctx = llama_n_ctx(ctx);
 
     if ((int) embd_inp.size() > n_ctx - 4) {
         fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
         return 1;
     }
 
-    // debug message about similarity of saved session, if applicable
-    size_t n_matching_session_tokens = 0;
-    if (session_tokens.size()) {
-        for (llama_token id : session_tokens) {
-            if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
-                break;
-            }
-            n_matching_session_tokens++;
-        }
-        if (n_matching_session_tokens >= embd_inp.size()) {
-            fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
-        } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
-            fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
-                __func__, n_matching_session_tokens, embd_inp.size());
-        } else {
-            fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
-                __func__, n_matching_session_tokens, embd_inp.size());
-        }
-    }
-
     // number of tokens to keep when resetting context
     if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
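For intuition about the matching logic above: the loop counts tokens from the start of the saved session until the first position that disagrees with the new prompt, or until the prompt runs out. A standalone toy version with made-up token ids:

    #include <cstdio>
    #include <vector>

    // Standalone toy version of the prefix-matching loop above, with made-up
    // token ids (llama_token is a plain int in llama.h).
    int main() {
        const std::vector<int> session_tokens = {10, 20, 30, 40}; // from the session file
        const std::vector<int> embd_inp       = {10, 20, 99};     // new prompt tokens

        int n_matching_session_tokens = 0;
        for (int id : session_tokens) {
            if (n_matching_session_tokens >= (int) embd_inp.size() ||
                id != embd_inp[n_matching_session_tokens]) {
                break;
            }
            n_matching_session_tokens++;
        }

        // prints 2: the saved state for the first two tokens can be reused;
        // everything after the first mismatch must be re-evaluated
        printf("matching prefix: %d tokens\n", n_matching_session_tokens);
        return 0;
    }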
@@ -283,16 +280,11 @@ int main(int argc, char ** argv) {
     bool is_antiprompt = false;
     bool input_echo    = true;
 
-    // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
-    // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
-    // initial prompt so it doesn't need to be an exact match.
-    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
-
-    int n_past     = 0;
-    int n_remain   = params.n_predict;
-    int n_consumed = 0;
-    int n_session_consumed = 0;
+    int n_past               = 0;
+    int n_remain             = params.n_predict;
+    int n_consumed           = 0;
+    int n_session_consumed   = 0;
+    int n_session_write_past = 0;
 
     // the first thing we will do is to output the prompt, so set color accordingly
     set_console_color(con_st, CONSOLE_COLOR_PROMPT);
@@ -306,17 +298,15 @@ int main(int argc, char ** argv) {
         // if we run out of context:
         // - take the n_keep first tokens from the original prompt (via n_past)
         // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-        if (n_past + (int) embd.size() > n_ctx) {
+        bool needs_swap = n_past + (int) embd.size() > n_ctx;
+        if (needs_swap) {
             const int n_left = n_past - params.n_keep;
 
             n_past = params.n_keep;
 
             // insert n_left/2 tokens at the start of embd from last_n_tokens
             embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
 
-            // stop saving session if we run out of context
-            path_session = "";
-
             //printf("\n---\n");
             //printf("resetting: '");
             //for (int i = 0; i < (int) embd.size(); i++) {
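To make the swap arithmetic concrete, here is a standalone toy calculation; the numbers are assumptions, not taken from the PR:

    #include <cstdio>

    // Toy version of the context-swap arithmetic above, with assumed values:
    // n_ctx = 512, n_keep = 48, and a full context (n_past = 512) when one
    // more token arrives for evaluation.
    int main() {
        const int n_ctx     = 512;
        const int n_keep    = 48;
        const int embd_size = 1;    // pending tokens in this iteration
        int n_past          = 512;  // context is full

        const bool needs_swap = n_past + embd_size > n_ctx;  // 513 > 512 -> true
        if (needs_swap) {
            const int n_left = n_past - n_keep;  // 464 tokens past the keep region
            n_past = n_keep;                     // rewind to the kept prefix
            // half of n_left (232 tokens) is re-inserted for re-evaluation,
            // so roughly half the old context is recomputed after a swap
            printf("keeping %d, re-evaluating %d tokens\n", n_keep, n_left / 2);
        }
        return 0;
    }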
@@ -326,27 +316,12 @@ int main(int argc, char ** argv) {
             //printf("\n---\n");
         }
 
-        // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
-        // REVIEW
-        if (n_session_consumed < (int) session_tokens.size()) {
-            size_t i = 0;
-            for ( ; i < embd.size(); i++) {
-                if (embd[i] != session_tokens[n_session_consumed]) {
-                    session_tokens.resize(n_session_consumed);
-                    break;
-                }
-
-                n_past++;
-                n_session_consumed++;
-
-                if (n_session_consumed >= (int) session_tokens.size()) {
-                    ++i;
-                    break;
-                }
-            }
-            if (i > 0) {
-                embd.erase(embd.begin(), embd.begin() + i);
-            }
-        }
+        // skip evaluation of tokens in the input prefix that matched session
+        if (n_session_consumed < n_matching_session_tokens) {
+            int n_skip = std::min((int) embd.size(), n_matching_session_tokens - n_session_consumed);
+            embd.erase(embd.begin(), embd.begin() + n_skip);
+            n_session_consumed += n_skip;
+            n_past += n_skip;
+        }
 
         // evaluate tokens in batches
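The replacement skip logic no longer walks session_tokens and truncates it on mismatch; it simply erases up to n_matching_session_tokens tokens from the front of the pending batch and fast-forwards the counters. A toy run with assumed sizes:

    #include <algorithm>
    #include <cstdio>

    // Toy run of the skip logic above with assumed sizes: 100 prompt tokens
    // matched the saved session and the current batch holds 64 of them.
    int main() {
        const int n_matching_session_tokens = 100;
        int n_session_consumed = 0;
        int n_past             = 0;
        int embd_size          = 64;  // stands in for embd.size()

        if (n_session_consumed < n_matching_session_tokens) {
            const int n_skip = std::min(embd_size, n_matching_session_tokens - n_session_consumed);
            embd_size          -= n_skip;  // stands in for embd.erase(begin, begin + n_skip)
            n_session_consumed += n_skip;
            n_past             += n_skip;
        }

        // prints 64 and 0: the whole batch is skipped instead of re-evaluated
        printf("n_past = %d, tokens left to eval = %d\n", n_past, embd_size);
        return 0;
    }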
@@ -363,14 +338,42 @@ int main(int argc, char ** argv) {
                 n_past += n_eval;
             }
 
-            if (embd.size() > 0 && !path_session.empty()) {
-                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
-                n_session_consumed = session_tokens.size();
+            // save session after context swap
+            if (!path_session.empty() && needs_swap) {
+                int n_tokens = n_past - params.n_keep;
+                if (!llama_append_session_file(
+                        ctx, path_session.c_str(), params.n_keep,
+                        last_n_tokens.data() + last_n_tokens.size() - n_tokens, n_tokens)) {
+                    fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
+                        __func__, path_session.c_str());
+                    return 1;
+                }
+
+                n_session_write_past = n_past;
             }
         }
 
         embd.clear();
 
+        // save prompt evaluation state to session file
+        if (!path_session.empty() && !n_session_write_past && (int) embd_inp.size() <= n_consumed) {
+            if (!llama_init_session_file(ctx, path_session.c_str())) {
+                fprintf(stderr, "%s: error: unable to start session file '%s'\n",
+                    __func__, path_session.c_str());
+                return 1;
+            }
+
+            if (!llama_append_session_file(
+                    ctx, path_session.c_str(), 0,
+                    last_n_tokens.data() + last_n_tokens.size() - n_past, n_past)) {
+                fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
+                    __func__, path_session.c_str());
+                return 1;
+            }
+
+            n_session_write_past = n_past;
+        }
+
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
             const float temp = params.temp;

Review thread on lines +341 to +352 (the "save session after context swap" block):

Reviewer: And the whole prompt would have to be >= n_ctx in size for this save to trigger, but using that whole prompt next time will give us the error "prompt is too long".

Author: Yeah, per my other comment I'll probably drop this. To be clear though, if you're talking about what's saved in the session, the state changes are incremental. The original prompt state will always be there in the first segment of the session regardless of what comes after, and that's all the startup code loads.
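The author's reply describes the session file as a sequence of appended segments. A rough picture of that layout (the labels are illustrative; the actual on-disk format is defined by the session-file code in llama.cpp):

    segment 0: token offset 0        -> prompt tokens + their evaluated state
    segment 1: token offset n_keep   -> state appended after the first context swap
    segment 2: ...                   -> later swaps, plus the remaining state at exit

    On startup, main only restores the prompt from the first segment, so
    segments appended later never invalidate the cached prompt state.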
@@ -387,12 +390,6 @@ int main(int argc, char ** argv) {
             const float mirostat_eta = params.mirostat_eta;
             const bool  penalize_nl  = params.penalize_nl;
 
-            // optionally save the session on first sample (for faster prompt loading next time)
-            if (!path_session.empty() && need_to_save_session) {
-                need_to_save_session = false;
-                llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
-            }
-
             llama_token id = 0;
 
             {
@@ -608,6 +605,20 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!path_session.empty()) {
+        int n_session_remain = n_past - n_session_write_past;
+        fprintf(stderr, "\n%s: saving remaining state (%d tokens) to session file '%s'",
+            __func__, n_session_remain, path_session.c_str());
+        if (!llama_append_session_file(
+                ctx, path_session.c_str(), n_session_write_past,
+                last_n_tokens.data() + last_n_tokens.size() - embd.size() - n_session_remain,
+                n_session_remain)) {
+            fprintf(stderr, "%s: error: unable to write to session file '%s'\n",
+                __func__, path_session.c_str());
+            return 1;
+        }
+    }
+
     llama_print_timings(ctx);
     llama_free(ctx);
Review thread on the 9-line similarity report in the prompt-loading hunk:

Reviewer: I'd prefer a one-line solution for displaying the count of matching tokens; these 9 lines add verbosity to an already bloated file.

Author: Yeah, that's fair.
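For reference, a one-line version along the suggested lines could reuse the report's final branch for all cases (a sketch using the diff's own variables; it would drop the exact-match and low-similarity wording):

    fprintf(stderr, "%s: session file matches %d / %zu tokens of prompt\n",
        __func__, n_matching_session_tokens, embd_inp.size());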