add kvcache, async eval, etc for #93 #109
Conversation
if let cache {
    queries = rope(queries, offset: cache.offset)
    keys = rope(keys, offset: cache.offset)
    (keys, values) = cache.update(keys: keys, values: values)
There are a lot of very mechanical changes like this to support the new KVCache. The KVCache is a reference type so we pass it in and mutate it rather than pass in/pass out.
public let vocabularySize: Int
public let kvHeads: [Int]
public let headDim: IntOrPair
Support for creating the KVCache
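These per-layer values are what let generic code build one cache per transformer layer. A rough sketch of the idea (the stub type and the makeCaches helper are illustrative, not the PR's actual API, and headDim is simplified to a single Int here):

// illustrative stub -- the real KVCache also holds the cached keys/values and an offset
final class LayerCacheStub {
    let kvHeads: Int
    let headDim: Int

    init(kvHeads: Int, headDim: Int) {
        self.kvHeads = kvHeads
        self.headDim = headDim
    }
}

// hypothetical helper: kvHeads has one entry per layer, so mapping it
// produces one cache per layer, sized from the model configuration
func makeCaches(kvHeads: [Int], headDim: Int) -> [LayerCacheStub] {
    kvHeads.map { LayerCacheStub(kvHeads: $0, headDim: headDim) }
}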
out = matmul(out, model.embedTokens.weight.T)
out = out * self.logitScale
return (out, cache)
return out
Every model is changed in this way -- they aren't all identical changes, but 70% are. The ones that are not identical are just different because of the specifics in the attention layers.
@@ -5,56 +5,12 @@ import MLX
import MLXRandom
import Tokenizers
private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray { |
These functions move inside the SampleContext for encapsulation. They also get compiled to match the python side.
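A rough sketch of what that encapsulation can look like -- simplified to temperature-only sampling; the PR's SampleContext also handles top-p and wraps the work in compile() to match python, both omitted here:

import MLX
import MLXRandom

// simplified illustration of pulling sampling into a context struct
struct SampleContextSketch {
    let temperature: Float

    func sample(logits: MLXArray) -> MLXArray {
        if temperature == 0 {
            // greedy: take the most likely token
            return argMax(logits, axis: -1)
        } else {
            // categorical sampling over temperature-scaled logits
            return MLXRandom.categorical(logits * (1 / temperature))
        }
    }
}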
}

/// Encapsulation of the repetitionPenalty
struct RepetitionContext: Sendable {
Refactored out the repetition context code for encapsulation -- I found that this was missing the token append line.
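For illustration, the bookkeeping half of that refactor might look like this (a sketch only: applying the penalty to the logits is omitted, and the property names are assumptions based on the diff):

// sketch of the repetition-penalty bookkeeping: keep a sliding window of the
// most recent tokens so the penalty can be applied to just those
struct RepetitionContextSketch: Sendable {
    // tokens considered for the penalty, most recent last
    var tokens = [Int]()

    // size of the sliding window of recent tokens
    let repetitionContextSize: Int

    mutating func append(token: Int) {
        // this append is the step the original code was missing
        tokens.append(token)
        if tokens.count > repetitionContextSize {
            tokens.removeFirst(tokens.count - repetitionContextSize)
        }
    }
}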
// compute the next state and async eval the next token
y = step(previous: previousY)
asyncEval(y)
This restructuring supports async eval now. We generate the graph for the next token and start the eval while returning the previously generated token.
For comparison here is the python version:
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
This is iterator style, so we are not yielding; I restructured it a bit to keep it clear.
Not for this PR, but now that we can easily (memory-wise) evaluate off the main thread we could change this to an AsyncStream, which does have yield. IIRC there were some issues around cleanly shutting down, but I think those can be resolved.
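Roughly, the iterator-style equivalent in Swift looks like this (a simplified sketch, not the PR's exact code; termination on EOS or a token limit is left to the caller):

import MLX

// sketch of the restructuring: build the graph for the next token and start its
// async evaluation, then return the previously generated token, which is already
// (or nearly) evaluated by the time the caller uses it
struct TokenIteratorSketch: IteratorProtocol {
    var previousY: MLXArray
    let step: (MLXArray) -> MLXArray

    init(prompt: MLXArray, step: @escaping (MLXArray) -> MLXArray) {
        self.step = step
        self.previousY = step(prompt)
        asyncEval(previousY)
    }

    mutating func next() -> Int? {
        // compute the next state and async eval the next token
        let y = step(previousY)
        asyncEval(y)

        // item() only waits for the previous token's evaluation
        let token = previousY.item(Int.self)
        previousY = y
        return token
    }
}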
Prompt Tokens per second: \(promptTokensPerSecond.formatted())
Generation tokens per second: \(tokensPerSecond.formatted())
Prompt: \(promptTokens.count) tokens, \(promptTokensPerSecond.formatted()) tokens/s
Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
Updated the summary stats to match the current python code
Libraries/LLM/KVCache.swift
import Foundation
import MLX

public class KVCache: Evaluatable {
The KVCache. I will add the rotating cache as well, but the default max_kv_size is None, so we only need this for the performance tuning.
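For reference, the core idea reduces to something like this -- a simplified sketch that just concatenates on every update (the PR's KVCache may grow its buffers differently for performance, and the shape/axis convention here is an assumption):

import MLX

// simplified sketch of the key/value cache: a reference type holding the keys and
// values seen so far, plus an offset used for the RoPE position computation
public class KVCacheSketch {
    var keys: MLXArray?
    var values: MLXArray?
    public private(set) var offset = 0

    public func update(keys newKeys: MLXArray, values newValues: MLXArray) -> (MLXArray, MLXArray) {
        // shapes assumed to be [batch, kvHeads, sequenceLength, headDim];
        // new entries are appended along the sequence axis (2)
        keys = keys.map { concatenated([$0, newKeys], axis: 2) } ?? newKeys
        values = values.map { concatenated([$0, newValues], axis: 2) } ?? newValues
        offset += newKeys.dim(2)
        return (keys!, values!)
    }
}

With this shape an attention layer can call cache.update(keys:values:) and use cache.offset for rope, as in the diff at the top of this review.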
}

public struct NaiveStreamingDetokenizer: StreamingDetokenizer {
The new NaiveStreamingDetokenizer
if let new = detokenizer.next() {
    print(new, terminator: "")
    fflush(stdout)
}
Using the new detokenizer
That is very strange... all of the fixes you pulled in should have gotten us to basically parity with Python, but it's still substantially slower 🤔 Could be some perf improvements in the updated MLX (17.0) as well, not sure what you were running.
That's not an MLX thing... it looks like it's coming from openblas, which is probably from NumPy if I had to guess. Very odd though... it's possibly dependent on the numpy version / back-end.
@davidkoski any reason this is still in draft? It looks more or less ready for review/merge?
Still needs:
I wanted to figure out the rest of the performance gap but I think this is useful enough that we could take it.
I was running 16.0 for both. One thing I will check is the implementation of the LLM model code -- maybe that has diverged.
@awni this is ready for review. It requires the changes from ml-explore/mlx-swift#127 so I will have one more commit after I merge and tag that one.
Awesome!! Will do so asap. |
@@ -78,54 +34,188 @@ public struct GenerateParameters: Sendable {
}
}

struct SampleContext {
Encapsulation of the sampling
@@ -195,20 +285,17 @@ public func generate(
{
// compute the timing for the prompt
if tokens.isEmpty {
    eval(token)
Switched to return Int -- returning the MLXArray is sketchy at best since it had not been evaluated, and you have to be very careful with threads. Anyway, this is internal to the generate() function.
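A tiny sketch of the difference (illustrative only; the argMax just stands in for whatever sampling produced the token):

import MLX

// returning an unevaluated MLXArray hands lazy state across an API boundary and
// across threads; converting to Int forces evaluation at a known point inside generate()
func nextTokenSketch(logits: MLXArray) -> Int {
    let y = argMax(logits, axis: -1)
    return y.item(Int.self)  // evaluation happens here, not in the caller
}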
- sampling code compiled
- KVCache
- async_eval
- NaiveStreamingDetokenizer
- use mlx-swift 16.0.1
Libraries/LLM/Tokenizer.swift
let newSegment = tokenizer.decode(tokens: segmentTokens)
let new = newSegment.suffix(newSegment.count - segment.count)

if new.contains("\n") {
Shouldn't that be ends with?
No, and perhaps that is a bug with the python detokenizer. Adding a token will go from foo to foo\n\nbar.
Hmm... but in that case it might not be safe to split the stream after \n\nbar? The implicit assumption is that whenever the text ends in \n, the detokenization is always the same whether it is done on the full string or split into two chunks at \n.
Let me get an example -- maybe we need to discuss this.
For the first paragraph in the example it produces this:
tokens: [1954, 22094, 29892, 9425, 1006, 2029, 3406, 29892, 393, 12595, 14973, 1063, 8608, 287, 1549, 278, 286, 2879, 310, 931, 304, 385, 21502, 305, 988, 521, 2561, 719, 20913, 29879, 322, 3514, 332, 1794, 696, 314, 278, 12625, 310, 1238, 566, 284, 5546, 29889, 2567, 29892, 7623, 393, 266, 457, 10435, 322, 5075, 272, 505, 1063, 8611, 411, 263, 16624, 936, 25792, 29892, 20794, 14904, 304, 263, 1855, 29885, 310, 1095, 2222, 24496, 322, 7134, 29889, 512, 445, 1855, 29885, 29892, 2012, 310, 8121, 303, 292, 363, 10657, 322, 20123, 29892, 13874, 21126, 338, 304, 16508, 29714, 322, 3957, 4822, 21188, 12625, 29889, 29871, 13, 13]
string: "Imagine, dear interlocutor, that thou hast been transported through the mists of time to an epoch where chivalry reigns and samurai roam the lands of feudal Japan. Now, picture that thine horse and armor have been replaced with a mystical portal, bringing thee to a realm of endless possibilities and knowledge. In this realm, instead of jousting for honor and territory, thy quest is to seek wisdom and connection across distant lands. "
The next token is \n\nThis:
tokens: [1954, 22094, 29892, 9425, 1006, 2029, 3406, 29892, 393, 12595, 14973, 1063, 8608, 287, 1549, 278, 286, 2879, 310, 931, 304, 385, 21502, 305, 988, 521, 2561, 719, 20913, 29879, 322, 3514, 332, 1794, 696, 314, 278, 12625, 310, 1238, 566, 284, 5546, 29889, 2567, 29892, 7623, 393, 266, 457, 10435, 322, 5075, 272, 505, 1063, 8611, 411, 263, 16624, 936, 25792, 29892, 20794, 14904, 304, 263, 1855, 29885, 310, 1095, 2222, 24496, 322, 7134, 29889, 512, 445, 1855, 29885, 29892, 2012, 310, 8121, 303, 292, 363, 10657, 322, 20123, 29892, 13874, 21126, 338, 304, 16508, 29714, 322, 3957, 4822, 21188, 12625, 29889, 29871, 13, 13, 4013]
string: "Imagine, dear interlocutor, that thou hast been transported through the mists of time to an epoch where chivalry reigns and samurai roam the lands of feudal Japan. Now, picture that thine horse and armor have been replaced with a mystical portal, bringing thee to a realm of endless possibilities and knowledge. In this realm, instead of jousting for honor and territory, thy quest is to seek wisdom and connection across distant lands.
This"
The next paragraph starts like this:
tokens: [4013, 1855]
string: "This real"
tokens: [4013, 1855, 29885]
string: "This realm"
So it is the sequence of tokens that produces the \n\n, not the 4013 itself.
Offline discussion:
- python is seeing newlines at the end of the string (I thought it wasn't but I can't be sure what I was looking at now)
- swift (different tokenizer implementation) sees newlines at the end with some models and not others
- we will use hasSuffix() to match the python behavior (see the sketch below)
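Putting that together, the segment handling looks roughly like this -- a simplified sketch of the naive streaming approach using hasSuffix; the method names mirror the usage shown earlier, the max(0, ...) guard is added defensively, and the PR's NaiveStreamingDetokenizer may differ in details:

import Tokenizers

// sketch: re-decode the current segment each step, emit only the newly added
// suffix, and start a fresh segment once the text ends in a newline -- a point
// where splitting the stream does not change the detokenization
struct NaiveStreamingDetokenizerSketch {
    let tokenizer: any Tokenizer
    var segmentTokens = [Int]()
    var segment = ""

    mutating func append(token: Int) {
        segmentTokens.append(token)
    }

    mutating func next() -> String? {
        let newSegment = tokenizer.decode(tokens: segmentTokens)
        let new = newSegment.suffix(max(0, newSegment.count - segment.count))

        if newSegment.hasSuffix("\n") {
            // safe split point -- start a new segment
            segmentTokens.removeAll()
            segment = ""
        } else {
            segment = newSegment
        }
        return new.isEmpty ? nil : String(new)
    }
}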
This is awesome!! 🚀
Just one comment, lmk what you think. O/w good to go!
My starting point with mlx v0.16.0:
Prompt is:
Here are the initial performance numbers:
Looking at the profiler we can see:
I am not sure what the blas_thread_server is, but it is soaking up a bunch of CPU time -- it looks like it is running async to the main computation, so other than efficiency it isn't affecting performance. If we subtract the blas_thread_server from the CPU time it is clear the python code is doing less work.

I found the following:
The prompt splitting didn't help here, as the default split size is 512 and that is larger than the prompt. The output in this case is only 673 tokens, which isn't particularly long either, so some of these changes may not be showing their full benefit.
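For context, prompt splitting here means pre-filling the cache by pushing the prompt through the model in fixed-size chunks rather than all at once. A rough sketch, where the forward closure stands in for the model call that updates the KV cache as a side effect and 512 is the default split size mentioned above:

import MLX

// sketch of chunked prompt pre-fill: evaluate splitSize tokens at a time so each
// forward pass stays bounded while the cache accumulates state between chunks
func prefillSketch(promptTokens: [Int], splitSize: Int = 512, forward: (MLXArray) -> MLXArray) {
    var start = 0
    while start < promptTokens.count {
        let end = min(start + splitSize, promptTokens.count)
        let chunk = MLXArray(promptTokens[start ..< end].map { Int32($0) })
        // add a batch dimension and evaluate so this chunk's cache entries are materialized
        eval(forward(expandedDimensions(chunk, axis: 0)))
        start = end
    }
}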
With these in place the performance is getting closer:
We have a big win in CPU time and even total wall time (startup time I presume) but are still nearly a second slower on the generation, so there is still some more work needed.
Part 2!
Here is the output from swift -- for comparison I had 114 tokens/s from python (though it was usually lower than that and I found the number to be noisy).
Here is the output from Instruments showing the synchronous evaluation (via item()), the async_eval() calls, and the GPU activity in the python program:

Notice that the GPU stays active most of the time.
Here is a similar view of the swift program (at around 97 tokens/s) from part 1:
The pattern is similar, but there are ~1.5ms gaps where the GPU goes idle -- the scheduler didn't have any work to do.
The final version looks like this:
The GPU is kept busy and we are hitting our performance targets!
The final changes resolved an issue with the asyncEval() and item()/eval() calls that were causing item() to wait for the completion of the asyncEval() -- see ml-explore/mlx-swift#127.