This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Implementation of prompt caching #14

Merged
merged 17 commits on Mar 17, 2023

Conversation

@setzer22 (Collaborator) commented Mar 15, 2023

This was simpler than I expected, so time to sneak in a little PR :)

This implements the same feature described in ggerganov/llama.cpp#64

Basically, this adds an API to access the memory tensors in the model, plus a couple of functions to save them to disk and load them back.

I also added --cache-prompt and --restore-prompt flags for llama-cli:

  • When running with --cache-prompt, no inference is run: the program feeds the prompt to the model and then dumps the memory contents into a file.
  • When running with --restore-prompt, the memory contents are read from the file at the given path, and the prompt you pass is appended right after the cached prompt.

Note this PR builds on top of #10, so the diff will not look correct until that is merged. You have to filter to only look at the commits from this PR when you go to the "Files changed" tab.

I have tested the changes, but since this is a non-trivial change, I'd like to make sure I didn't break anything before merging. Anyone interested, please pull the branch and test it before we merge. Here are some minimal instructions.

  1. Run the model with an incomplete prompt and cache it, using the prompt below as an example:
RUSTFLAGS='-C target-feature=+avx2,+fma,+f16c'  cargo run --release -- -m /data/Llama/LLaMA/7B/ggml-model-q4_0.bin -f <path_to_prompt_text_file> --cache-prompt <path_to_cache_file>
The following text is a transcript for a conversation between a human user and The Assistant. The Assistant is a smart, reliable, caring, confident chatbot which is based on advanced AI technology and is capable of replying to the user's messages in a truthful and understanding way.

The transcript consists of an exchange of messages. User messages begin with the word USER, while The Assistant's messages start by using the word ASSISTANT.

=== BEGIN TRANSCRIPT ===
USER: Explain what a game engine is like a 5 year old
ASSISTANT:
  2. Run the command again, this time passing the --restore-prompt flag and a smaller prompt:
RUSTFLAGS='-C target-feature=+avx2,+fma,+f16c'  cargo run --release -- -m /data/Llama/LLaMA/7B/ggml-model-q4_0.bin -p "A game engine" --restore-prompt <path_to_cache_file> 

The observed behavior should be that the second time, the system starts predicting just where it left off, and you only pay for the time it takes to parse the extra prompt :)

@setzer22 changed the title from "Implement prompt caching" to "Implementation of prompt caching" on Mar 15, 2023
Commit: "We don't do compression, because it's slower (compression ratios were quite good though)."
@setzer22 (Collaborator, Author) left a comment:

Some comments as a self-review (for those new to the repo: I always try to do this, since I think it helps others understand the changes 😄)

.expect("Could not load model");

if let Some(restore_path) = &args.restore_prompt {
let memory = ModelMemory::load_compressed(restore_path);
setzer22 (Collaborator, Author):

You can load a stored memory file from disk like this. I chose to add a simple API with bincode, but the struct implements De/Serialize so you can do your own serialization instead.
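For illustration, custom serialization could look something like this (just a sketch, not part of the PR: it assumes serde_json as a dependency, and save_memory_json / load_memory_json are hypothetical helpers):

use std::error::Error;
use std::fs::File;
use std::io::{BufReader, BufWriter};

// Hypothetical helpers: any serde-compatible format works, because
// ModelMemory implements Serialize/Deserialize.
fn save_memory_json(memory: &ModelMemory, path: &str) -> Result<(), Box<dyn Error>> {
    let writer = BufWriter::new(File::create(path)?);
    serde_json::to_writer(writer, memory)?;
    Ok(())
}

fn load_memory_json(path: &str) -> Result<ModelMemory, Box<dyn Error>> {
    let reader = BufReader::new(File::open(path)?);
    Ok(serde_json::from_reader(reader)?)
}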

Comment on lines 73 to 76
// SAFETY: no other model functions used inside the block
unsafe {
let memory = model.get_memory();
match memory.write_compressed(cache_path) {
setzer22 (Collaborator, Author):

And this is how you access tensor memory and write it to a file.

This returns a reference to the tensor memory inside the inner context, so it's not safe to access it and then keep using the model object, because the model might mutate that memory underneath and cause UB. I decided to mark the function unsafe out of caution, but there's a broader underlying safety issue here: the ggml bindings allow mutation via a shared ref &. (More on that below.)

Comment on lines 614 to 619
if n_tensors % 8 == 0 {
print!(".");
std::io::stdout().flush();
}
}
println!(".");
setzer22 (Collaborator, Author):

Brought this back temporarily because I was going crazy without it 😄 I'll just end up rebasing onto whatever we agree on as part of #10.

Comment on lines 641 to 642
let beginning_of_sentence = self.n_past == 0;
let embd_inp = self.tokenize(vocab, prompt, beginning_of_sentence);
setzer22 (Collaborator, Author):

One sneaky issue I found: we don't want to feed the beginning-of-sentence token when restoring from memory, because that inserts the token into the memory and the model then ignores any tokens that come before it.

Comment on lines 606 to 612
let _ = self.llama_eval(
params.n_threads,
0,
&[0, 1, 2, 3],
&mut logits,
&mut mem_per_token,
);
setzer22 (Collaborator, Author):

This is one important change that might (?) cause issues on the larger models. The way llama_eval works is that when you pass mem_per_token as zero, it computes the per-token memory requirement and writes it back for you, so you can pass the measured value in successive calls.

But this bogus call caused issues when we were resuming from a cached prompt (because it overwrites some memory tokens), so I decided to simply remove it. This apparently works just fine, and thinking about it, I'm not even sure what the intent of the original code was, but let's double-check just in case 🤔
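For context, this is roughly how mem_per_token flows between calls (a sketch based on the excerpt above, not the exact code; the variable types are assumptions):

// mem_per_token starts at zero; the first llama_eval call measures the
// per-token scratch memory and writes it back through the &mut reference.
let mut mem_per_token: usize = 0;
let mut logits: Vec<f32> = Vec::new();

// First real evaluation, over the actual prompt tokens.
let _ = self.llama_eval(params.n_threads, self.n_past as i32, &embd, &mut logits, &mut mem_per_token);

// Later calls reuse the measured value instead of re-probing it, which is
// why the old "warm-up" call with dummy tokens isn't needed anymore.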

&embd,
&mut logits,
&mut mem_per_token,
);
}

n_past += embd.len() as i32;
self.n_past += embd.len();
setzer22 (Collaborator, Author):

n_past is now stored inside the LlamaModel object. This allows invoking the generation multiple times. When n_past > 0, inference continues by first feeding any extra prompt you pass, then generating normally.

This could also serve as the basis for a more complex chat-like interaction :)
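For example, something like this should now work (a sketch using the inference_with_prompt signature from this PR; printing the token assumes OutputToken implements Display):

let mut rng = rand::thread_rng();

// First call: feed the initial prompt and generate a reply.
model.inference_with_prompt(
    &vocab,
    &params,
    "USER: Explain what a game engine is\nASSISTANT:",
    &mut rng,
    false, // stop_after_prompt: keep generating once the prompt is fed
    |t| print!("{t}"),
);

// Second call: n_past > 0 now, so this prompt is appended after the
// previous output and generation simply continues from there.
model.inference_with_prompt(
    &vocab,
    &params,
    "\nUSER: Now explain it to a programmer\nASSISTANT:",
    &mut rng,
    false,
    |t| print!("{t}"),
);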

Comment on lines 679 to 685

// TODO: The logic is super convoluted already and this only
// makes it worse... Refactor!
if stop_after_prompt {
break;
}

setzer22 (Collaborator, Author):

This was necessary because we don't want to do any inference when caching prompts, just stop right after the prompt has been fed to the system.

Adding this break statement here made me realize how convoluted this logic already is. But we can fix that in a later PR.

@setzer22 (Collaborator, Author) commented:

Oh, before I forget. I also tried using the snap crate for compression. Some quick results:

  • Compression ratios for the prompt I'm sharing above look good: a cached prompt is around 500 MB, but only 250 MB when compressed. However, this ratio may just be an artifact of the ton of zeroes at the end of the (unfilled) memory.
  • Loading time is negligible for uncompressed data, but quite noticeable (~2 seconds) when compressed.

So, all things considered, I decided to not bother with compression right now. We can always add it later.
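If we ever revisit this, layering compression on top of the existing serialization could look roughly like so (a sketch only, assuming the snap crate and a bincode 1.x-style API; write_memory_snappy / read_memory_snappy are hypothetical):

use std::error::Error;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};

fn write_memory_snappy(memory: &ModelMemory, path: &str) -> Result<(), Box<dyn Error>> {
    let file = BufWriter::new(File::create(path)?);
    // Wrap the bincode stream in Snappy's framing format.
    let mut encoder = snap::write::FrameEncoder::new(file);
    bincode::serialize_into(&mut encoder, memory)?;
    encoder.flush()?; // make sure the last frame is written out
    Ok(())
}

fn read_memory_snappy(path: &str) -> Result<ModelMemory, Box<dyn Error>> {
    let file = BufReader::new(File::open(path)?);
    let decoder = snap::read::FrameDecoder::new(file);
    Ok(bincode::deserialize_from(decoder)?)
}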

@setzer22 setzer22 mentioned this pull request Mar 15, 2023
@philpax (Collaborator) left a comment:

Looks good to me. Would this also enable continuing an existing generation that ran out before its end-of-text?

@setzer22 (Collaborator, Author) commented Mar 16, 2023

> Looks good to me. Would this also enable continuing an existing generation that ran out before its end-of-text?

Not directly, but it would be a step in the right direction. Basically, we would need to manipulate the memory to drop the first tokens so that we free up space to continue generation.

I'm thinking that adding an API for manipulating the tokens in the memory would open up very interesting use cases for experimentation 🤔

With such an API you could, e.g., keep the beginning of the conversation, insert some (...) tokens in the middle, then keep the part at the end: you've made room to continue the generation, but the model still roughly keeps the context of the conversation. We need to try this!

@mwbryant (Contributor) left a comment:

Great new feature! I had some thoughts about the API, but the functionality seems like it will open up a lot of possibilities.

Comment on lines 632 to 639
pub fn inference_with_prompt<'a>(
&mut self,
vocab: &'a GptVocab,
params: &InferenceParams,
prompt: &str,
rng: &mut impl rand::Rng,
stop_after_prompt: bool,
callback: impl Fn(OutputToken<'a>),
mwbryant (Contributor):

Instead of adding a new parameter to this function, would it be better to break this out into two separate functions: one to feed the model prompt data and one to run inference? A bool here doesn't feel like the best API to me. The model would then need to track its own last_n_tokens and maybe more, but that doesn't seem unreasonable to me.

We might then be able to get away from the callback and just have inference return the tokens, giving the user more control over when to run inference instead of forcing it to run to completion, putting the core loop back into user code.

philpax (Collaborator):

What I was considering was exposing more of the raw inference loop, but wrapping it up in APIs that make it straightforward to use, so that people can choose the behaviour they need. Not sure what that API would look like yet (would need to look at it after work), but I would expect it to enable prompt caching and continued inference.

mwbryant (Contributor):

That sounds great! I'm coming from the Bevy world, so a good API in my opinion provides short-running functions that do the minimal unit of work and then give control back to the caller. The callback system is inverted from how I would want to use this in projects.

philpax (Collaborator):

Yeah, I added the callback as the smallest-change-possible solution to the existing mechanism. I don't want to expose too much of the details around actually driving the model, but I would like to enable users to dig around.

I think the easiest way to experiment with this is to replicate the loop in user code (I think all of the relevant functions are public) and see what we can hide behind the scenes while still allowing for flexibility. Suggestions welcome!

@setzer22 (Collaborator, Author) Mar 17, 2023:

How about something like this: a Model::feed_prompt method you can use to feed the model a sentence. This would not perform any inference, just feed the prompt.

Then, another method Model::next_token that, given the current state of the model memory, returns the next token with no external input and updates the internal memory. You can call this as many times as you want in a loop, react to the results, and stop whenever you want. It could easily be wrapped into an Iterator too, but I wouldn't go that far.

Finally, we can keep the original infer_with_prompt as a trivial wrapper over the two previous methods (I like backwards compatibility even if the code is only 3 days old lol 🤣).

This new structure would help a lot in untangling the current mess that infer_with_prompt is, and I don't see any downsides to it.

What do you think?
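In other words, something along these lines (a rough sketch of the shape only, with names taken from the discussion above; signatures and the model struct's actual name may differ):

impl Model {
    /// Tokenize and evaluate the prompt, updating the memory and n_past,
    /// without sampling anything.
    pub fn feed_prompt(&mut self, vocab: &GptVocab, params: &InferenceParams, prompt: &str) {
        todo!("tokenize + llama_eval over the prompt tokens")
    }

    /// Sample one token from the current state, evaluate it so the memory
    /// advances, and hand it back to the caller.
    pub fn next_token<'a>(
        &mut self,
        vocab: &'a GptVocab,
        params: &InferenceParams,
        rng: &mut impl rand::Rng,
    ) -> OutputToken<'a> {
        todo!("sample from the last logits, then llama_eval the sampled token")
    }

    // The existing inference_with_prompt would become a thin wrapper:
    // feed_prompt(...) followed by next_token(...) in a loop until end of text.
}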

philpax (Collaborator):

Seems reasonable to me! That also solves a problem I've had with llamacord - there's no way to stop generation when it starts. Having manual control over fetching the next token should sort that out.

Would you be able to feed a prompt after calling next_token?

setzer22 (Collaborator, Author):

Yes, feed_prompt and next_token could be interleaved in any way the user wants. 😄

philpax (Collaborator):

I'm reading up on https://til.simonwillison.net/llms/python-react-pattern and I'm now very excited for the potential of this new API. Are you planning on building it?

setzer22 (Collaborator, Author):

> Are you planning on building it?

The API? Yes 😁 It's up already!

I had no direct plans, but I'd be super interested in seeing ReAct working with LLaMA using this 👀

@setzer22 (Collaborator, Author) commented Mar 17, 2023

Alright, this was some major cleanup session! 😄

As discussed, I've broken down the old infer_with_prompt function into two functions: feed_prompt and infer_next_token. This helps untangle the mess the original function had become. A Model now offers a "poll"-based API: when you want to produce a token or feed some more prompt, you just call the corresponding method. It is now up to the caller to stop on "end of text" or do anything else entirely. 😄

As part of this refactor, I've also had to move some of the things that were previously declared inside infer_with_prompt into the Model struct. Functions now take fewer arguments and rely more on updating internal state. Things like last_n_tokens or n_past are now stored as part of the model.

Finally, I've refactored the Memory structs from the previous iteration of this PR into InferenceSnapshot, which stores more things than just the memory (for example, the last_n_tokens and the logits vector). This ensures inference will resume at the exact same point where it left off when restoring a snapshot.
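To make the new flow concrete, the caller side now looks roughly like this (a sketch; exact signatures may differ slightly from the diff, and printing the token assumes it implements Display):

// Feed the prompt first; no tokens are sampled yet.
model.feed_prompt(&vocab, &params, "USER: Explain what a game engine is\nASSISTANT:");

// Then poll for tokens one at a time; the caller decides when to stop.
loop {
    let token = model.infer_next_token(&vocab, &params, &mut rng);
    if let OutputToken::EndOfText = token {
        break;
    }
    print!("{token}");
}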

@philpax (Collaborator) left a comment:

Implementation looks good! Just going to make a few suggestions :)

// Write the memory to the cache file
// SAFETY: no other model functions used inside the block
unsafe {
let memory = model.get_snapshot();
philpax (Collaborator):

Is there a way to make this safe? Something like memory.with_snapshot(|snapshot| {}) so that access is bounded?
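Roughly this shape, for the record (purely hypothetical, not code from the PR; the return type of get_snapshot is assumed here):

impl Model {
    /// Bounded access: the closure borrows the snapshot, and the exclusive
    /// borrow of `self` keeps any other model call out for its duration.
    pub fn with_snapshot<R>(&mut self, f: impl FnOnce(&InferenceSnapshot) -> R) -> R {
        // SAFETY: `self` is exclusively borrowed, so no other model function
        // can mutate the underlying tensors while `f` runs.
        let snapshot = unsafe { self.get_snapshot() };
        f(&snapshot)
    }
}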

@setzer22 (Collaborator, Author) Mar 17, 2023:

Basically, this "unsafe" is here out of caution. This function should not be marked unsafe, because the returned object has the correct lifetime and prevents mutating the model for as long as it lives.

The problem here is that tensors inside ggml can be mutated through a shared ref &. When writing the function to return the slice, I realized all this and thought: "Let's just mark this as unsafe for the time being until we figure out a better way".

The whole "safety-ness" of the library needs to be reviewed, but I'd do that in another PR or we'll end up putting too much stuff into this one. For now, let's just push the safety requirements onto the caller.

philpax (Collaborator):

Fair enough, let's revisit later!

llama-rs/src/lib.rs: 4 outdated review threads (resolved)

/// Clears the model's state, so inference can be restarted without having
/// to load the weights again.
pub fn clear_state(&mut self) {
philpax (Collaborator):

Not a huge fan of this - it seems footgun-y (like it'd be easy to forget to clear the state after you're done). I'd suggest decoupling inference from the model itself, so that you have something like

let mut session: InferenceSession = model.infer();
session.feed_prompt("blah");

where all of the state relevant to this inference is in InferenceSession, and the model is purely data. (Bonus points if you have it borrow from Model.)

setzer22 (Collaborator, Author):

Hmmmm, you're right 🤔 This requires moving the memory_k and memory_v tensors into another struct though. They should probably be owned by an entirely separate ggml context stored inside the InferenceSession to make things easier.
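A rough sketch of that split (illustrative only; field and method names are assumptions, not code from this PR):

/// Everything that changes during one inference run.
pub struct InferenceSession<'m> {
    model: &'m Model,   // borrows the weights, so they are loaded only once
    n_past: usize,
    // memory_k / memory_v (owned by their own ggml context) would live here,
    // along with last_n_tokens, the logits vector, etc.
}

impl Model {
    /// Start a fresh session; dropping it throws away all inference state,
    /// so there is no clear_state() to forget to call.
    pub fn start_session(&self) -> InferenceSession<'_> {
        todo!("allocate the per-session memory tensors in a separate ggml context")
    }
}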


if eot {
while self.n_past < self.hparams.n_ctx as usize {
philpax (Collaborator):

This loop seems a little weird to me. Maybe have a comment explaining what it's doing for people who look at it to figure out how to structure their own loops?

@setzer22 (Collaborator, Author) Mar 17, 2023:

The idea is to continue inferring tokens until the model returns an EndOfText or we run out of context space, whichever happens first. I added an explanatory comment 👍

llama-rs/src/lib.rs: 3 outdated review threads (resolved)
@xingchensong

great job!
