Conversation
Some comments as a self-review (for those new to the repo: I always try to do this, since I think it helps others understand the changes 😄).
`llama-cli/src/main.rs`:

```rust
    .expect("Could not load model");

if let Some(restore_path) = &args.restore_prompt {
    let memory = ModelMemory::load_compressed(restore_path);
```
You can load a stored memory file from disk like this. I chose to add a simple API with `bincode`, but the struct implements `Deserialize`/`Serialize`, so you can do your own serialization instead.
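For illustration, a minimal sketch of what that custom serialization could look like (assuming `ModelMemory` derives serde's `Serialize`/`Deserialize` as mentioned above and that `bincode` 1.x is available; this is not verbatim from the PR):

```rust
use std::fs::File;
use std::io::{BufReader, BufWriter};

// Roll your own persistence instead of the built-in load_compressed/write_compressed helpers.
fn save_memory(memory: &ModelMemory, path: &str) -> Result<(), Box<dyn std::error::Error>> {
    let writer = BufWriter::new(File::create(path)?);
    bincode::serialize_into(writer, memory)?;
    Ok(())
}

fn load_memory(path: &str) -> Result<ModelMemory, Box<dyn std::error::Error>> {
    let reader = BufReader::new(File::open(path)?);
    Ok(bincode::deserialize_from(reader)?)
}
```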
`llama-cli/src/main.rs`:

```rust
// SAFETY: no other model functions used inside the block
unsafe {
    let memory = model.get_memory();
    match memory.write_compressed(cache_path) {
```
And this is how you access tensor memory and write it to a file. It returns a reference to the tensor memory inside the inner context, so it's not safe to grab that and then keep using the `model` object, because that might mutate the memory underneath and cause UB. I just decided to mark the function `unsafe` out of caution, but there's a whole underlying safety issue here: the `ggml` bindings allow mutation through a shared ref (`&`). (More on that below.)
`llama-rs/src/llama.rs`:

```rust
    if n_tensors % 8 == 0 {
        print!(".");
        std::io::stdout().flush();
    }
}
println!(".");
```
Brought this back temporarily because I was going crazy without it 😄 I'll just end up rebasing on whatever we agree on as part of #10.
`llama-rs/src/llama.rs`:

```rust
let beginning_of_sentence = self.n_past == 0;
let embd_inp = self.tokenize(vocab, prompt, beginning_of_sentence);
```
One sneaky issue I found: we don't want to feed the beginning-of-sentence token when restoring from memory, because that inserts the token into the memory and the model then ignores any tokens that come before it.
`llama-rs/src/llama.rs`:

```rust
let _ = self.llama_eval(
    params.n_threads,
    0,
    &[0, 1, 2, 3],
    &mut logits,
    &mut mem_per_token,
);
```
This is one important change that might (?) cause issues on the larger models. The way `llama_eval` works is that when you pass `mem_per_token` as zero, it computes the memory requirement and writes it back for you, so you can pass the measured value on successive calls. But this bogus call caused issues when resuming from a cached prompt (because it overwrites some memory tokens), so I decided to simply remove it. This apparently works just fine, and thinking about it, I'm not even sure what the intent of the original code was, but let's make extra sure just in case 🤔
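To make the new flow concrete, here is a rough sketch of the idea (variable names follow the diffs in this PR; the exact types are assumptions, not verbatim code):

```rust
// Instead of an up-front eval on the dummy tokens [0, 1, 2, 3], mem_per_token
// simply starts at zero; llama_eval measures the per-token memory on the first
// real call and writes the value back for subsequent calls.
let mut logits: Vec<f32> = Vec::new();
let mut mem_per_token = 0;
let _ = self.llama_eval(
    params.n_threads,
    self.n_past as i32,  // continue from wherever the (possibly restored) memory left off
    &embd_inp,           // the actual prompt tokens, not a bogus placeholder
    &mut logits,
    &mut mem_per_token,  // zero on the first call; filled in by llama_eval
);
```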
`llama-rs/src/llama.rs`:

```diff
     &embd,
     &mut logits,
     &mut mem_per_token,
 );
 }

-n_past += embd.len() as i32;
+self.n_past += embd.len();
```
`n_past` is now stored inside the `LlamaModel` object. This allows invoking the generation multiple times: when `n_past` is > 0, inference continues by adding any extra prompt you feed it, and then starts inferring normally.
This could also serve as the basis for a more complex chat-like interaction :)
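A quick usage sketch of what that enables (the signature is the one shown further below; the setup of `vocab`, `params`, and `rng`, as well as the token handling, are assumed):

```rust
// Because n_past persists on the model, the second call continues the same
// generation instead of starting over from an empty context.
model.inference_with_prompt(&vocab, &params, "Once upon a time,", &mut rng, false, |_token| {
    // print or collect the emitted token here
});

// Later: feed a little more prompt and keep generating from where we left off.
model.inference_with_prompt(&vocab, &params, " The next day,", &mut rng, false, |_token| {
    // handle the continued output
});
```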
`llama-rs/src/llama.rs`:

```rust
// TODO: The logic is super convoluted already and this only
// makes it worse... Refactor!
if stop_after_prompt {
    break;
}
```
This was necessary because we don't want to do any inference when caching prompts; we just stop right after the prompt has been fed to the system.
Adding this break statement here made me realize how convoluted this logic already is, but we can fix that in a later PR.
Oh, before I forget: I also tried using the snap crate for compression. Some quick results: compression was slower, even though the ratios were quite good. So, all things considered, I decided not to bother with compression right now. We can always add it later.
Looks good to me. Would this also enable continuing an existing generation that ran out before its end-of-text?
Not directly, but it would be a step in the right direction. Basically, we would need to manipulate the memory to drop the first tokens so that we free up space to continue generation. I'm thinking an API for manipulating the tokens in the memory would lead to very interesting use cases for experimentation 🤔 With such an API you could, e.g., keep the beginning of the conversation, insert some …
Great new feature! I had some thoughts about the API, but the functionality seems like it will open up a lot of possibilities.
`llama-rs/src/llama.rs`:

```rust
pub fn inference_with_prompt<'a>(
    &mut self,
    vocab: &'a GptVocab,
    params: &InferenceParams,
    prompt: &str,
    rng: &mut impl rand::Rng,
    stop_after_prompt: bool,
    callback: impl Fn(OutputToken<'a>),
```
Instead of adding a new parameter to this function, would it be better to break it out into two separate functions: one to feed the model prompt data, and one to run inference? A bool here doesn't feel like the best API to me. The model would then need to track its own `last_n_tokens`, and maybe more, but that doesn't seem unreasonable to me.
We might then be able to get away from the callback and just have the inference return the tokens, giving the user more control over when to run inference instead of forcing it to run to completion, and putting the core loop back into user code.
What I was considering was exposing more of the raw inference loop, but wrapping it up in APIs that make it straightforward to use, so that people can choose the behaviour they need. I'm not sure what that API would look like yet (I'd need to look at it after work), but I would expect it to enable both prompt caching and continued inference.
That sounds great! I'm coming from the Bevy world, so a good API in my opinion offers short-running functions that do a minimal unit of work and then give control back to the caller. The callback system is inverted compared to how I would want to use this in projects.
Yeah, I added the callback as the smallest-change-possible solution to the existing mechanism. I don't want to expose too much of the details around actually driving the model, but I would like to enable users to dig around.
I think the easiest way to experiment with this is to replicate the loop in user code (I think all of the relevant functions are public) and see what we can hide behind the scenes while still allowing for flexibility. Suggestions welcome!
How about something like this: a `Model::feed_prompt` method you can use to feed the model a sentence. This would not perform any inference, just process the prompt.
Then, another method `Model::next_token` that, given the current state of the model memory, returns the next token with no external input and modifies the internal memory. You can call this as many times as you want in a loop, react to the results, and stop whenever you want. It could easily be wrapped into an `Iterator` too, but I wouldn't go that far.
Finally, we can keep the original `infer_with_prompt` as a trivial wrapper over the two previous methods (I like backwards compatibility, even when the code is just 3 days old lol 🤣).
This new structure would help a lot in untangling the current mess that `infer_with_prompt` is, and I don't see any downsides to it.
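For illustration, a rough sketch of how user code could drive that split (method names come from this thread; the `OutputToken` variants and exact signatures are assumptions, not the final API):

```rust
// Feed the prompt without sampling anything...
model.feed_prompt(&vocab, &params, "Once upon a time");

// ...then the caller owns the generation loop and can stop at any point.
loop {
    match model.next_token(&vocab, &params, &mut rng) {
        OutputToken::Token(t) => print!("{t}"),  // react to each token as it arrives
        OutputToken::EndOfText => break,         // or break on any condition you like
    }
}
```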
What do you think?
Seems reasonable to me! That also solves a problem I've had with llamacord - there's no way to stop generation once it starts. Having manual control over fetching the next token should sort that out.
Would you be able to feed a prompt after calling `next_token`?
Yes, `feed_prompt` and `next_token` could be interleaved in any way the user wants. 😄
I'm reading up on https://til.simonwillison.net/llms/python-react-pattern and I'm now very excited for the potential of this new API. Are you planning on building it?
> Are you planning on building it?
The API? Yes 😁 It's up already!
I had no direct plans, but I'd be super interested in seeing ReAct working with LLaMA using this 👀
Alright, this was some major cleanup session! 😄 As discussed, I've broken down the old `inference_with_prompt` function […]. As part of this refactor, I've also had to move some of the things that were previously declared inside […]. Finally, I've refactored the […].
Implementation looks good! Just going to make a few suggestions :)
`llama-cli/src/main.rs`:

```rust
// Write the memory to the cache file
// SAFETY: no other model functions used inside the block
unsafe {
    let memory = model.get_snapshot();
```
Is there a way to make this safe? Something like `memory.with_snapshot(|snapshot| {})`, so that access is bounded?
Basically, this `unsafe` is here out of caution. The function should not really be marked `unsafe`, because the returned object has the correct lifetime and prevents mutating the model for as long as it lives.
The problem is that tensors inside `ggml` can be mutated through a shared ref (`&`). When writing the function to return the slice I realized all this and thought: "Let's just mark this as unsafe for the time being until we figure out a better way."
The whole safety story of the library needs to be reviewed, but I'd rather do that in another PR or we'll end up putting too much stuff into this one. For now, let's just push the requirements to the caller.
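For illustration, the kind of lifetime-bound accessor being described could look roughly like this (everything here is an assumption about shapes, not the PR's actual types; in particular the snapshot representation is invented):

```rust
// Holding the snapshot borrows the model, so the borrow checker prevents
// calling any &mut Model methods (i.e. mutating the memory) while it is alive.
pub struct MemorySnapshot<'a> {
    memory_k: &'a [u8],
    memory_v: &'a [u8],
}

impl Model {
    /// Return a read-only view of the memory tensors, tied to &self.
    pub fn get_memory(&self) -> MemorySnapshot<'_> {
        todo!("expose the k/v tensor data as slices")
    }
}
```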
Fair enough, let's revisit later!
`llama-rs/src/lib.rs`:

```rust
/// Clears the model's state, so inference can be restarted without having
/// to load the weights again.
pub fn clear_state(&mut self) {
```
Not a huge fan of this - it seems footgun-y (like it'd be easy to forget to clear the state after you're done). I'd suggest decoupling inference from the model itself, so that you have something like

```rust
let mut session: InferenceSession = model.infer();
session.feed_prompt("blah");
```

where all of the state relevant to this inference lives in `InferenceSession`, and the model is purely data. (Bonus points if you have it borrow from `Model`.)
Hmmmm, you're right 🤔 This requires moving the `memory_k` and `memory_v` tensors into another struct, though. They should probably be owned by an entirely separate ggml context stored inside the `InferenceSession` to make things easier.
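For the record, a hypothetical sketch of that shape (only the `InferenceSession` name and the `memory_k`/`memory_v` tensors come from the discussion above; the field types and the `infer` body are illustrative):

```rust
// All mutable inference state lives in the session; the Model stays read-only data.
pub struct InferenceSession {
    // A dedicated ggml context owning this session's tensors.
    session_ctx: ggml::Context,
    memory_k: ggml::Tensor,
    memory_v: ggml::Tensor,
    n_past: usize,
}

impl Model {
    /// Start a fresh inference session backed by its own memory tensors.
    pub fn infer(&self) -> InferenceSession {
        todo!("allocate a new ggml context plus memory_k/memory_v for the session")
    }
}
```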
`llama-rs/src/lib.rs`:

```rust
if eot {
    while self.n_past < self.hparams.n_ctx as usize {
```
This loop seems a little weird to me. Maybe add a comment explaining what it's doing, for people who look at it to figure out how to structure their own loops?
The idea is to continue inferring tokens until the model returns an EndOfText or we run out of space, whichever happens first. I added an explanatory comment 👍
great job!
This was simpler than I expected, so time to sneak in a little PR :)
This implements the same feature described in ggerganov/llama.cpp#64
Basically, this adds an API to access the memory tensors in the model, and adds a couple functions to save them and load them from disk.
I also added `--cache-prompt` and `--restore-prompt` flags for `llama-cli`:

- In `--cache-prompt` mode, no inference is run, and the program dumps the memory contents into a file after feeding it the prompt.
- With `--restore-prompt`, the contents of the memory are read from disk at the given path, and the prompt you feed the system is concatenated right after the cached prompt.

Note this PR builds on top of #10, so the diff will not look correct until that is merged. You have to filter to only look at the commits from this PR when you go to the "Files changed" tab.
I have tested the changes, but since this is a non-trivial change, I'd like to make sure I didn't break anything before merging. Anyone interested, please pull the branch and test it before we commit. Here are some minimal instructions:

- Run once with `--cache-prompt` to feed your prompt and dump the memory contents to a file.
- Run again with the `--restore-cache` flag and a smaller prompt.

The observed behavior should be that the second time, the system starts predicting just where it left off, and you only pay for the time it takes to parse the extra prompt :)