add kvcache, async eval, etc for #93 #109

Merged: 2 commits merged into main from kvcache on Aug 29, 2024

Conversation

@davidkoski (Collaborator) commented Aug 23, 2024

  • sampling code compiled
  • KVCache
  • async_eval
  • NaiveStreamingDetokenizer

My starting point with mlx v0.16.0:

./mlx-run llm-tool --model mlx-community/Phi-3-mini-4k-instruct-4bit -m 100000 --temperature 0 -p @/tmp/p.txt --cache-size 500 --temperature 0

Prompt is:

<|system|>
You are a helpful assistant.<|end|>
<|user|>
How to explain Internet for a medieval knight?<|end|>
<|assistant|>

Here are the initial performance numbers:

python Generation: 673 tokens, 114.404 tokens-per-sec, 5.873911874999999s
swift Generation: 673 tokens, 83.136954 tokens/s

Looking at the profiler we can see:

Task                   Python (s)   Swift (s)
Wall Time              9.6          12.1
CPU Time               10.94        9.91
Metal Submit           1.75         3.83
Metal didComplete      0.48         0.83
mlx::core::scheduler   2.92         3.63
blas_thread_server     2.9          -
main                   1.67         0.74

I am not sure what the blas_thread_server is, but it is soaking up a bunch of CPU time. It looks like it runs asynchronously to the main computation, so aside from the wasted efficiency it isn't affecting performance. If we subtract blas_thread_server from the Python CPU time, it is clear the python code is doing less work.

I found the following:

  • The sampling code in python is compiled — this made a pretty big difference
  • The KVCache also made a big difference
  • The async_eval in the python code helped a little
  • The NaiveStreamingDetokenizer helped a little

The prompt splitting didn't help here, as the default split size is 512 and that is larger than the prompt. The output in this case is only 673 tokens, which isn't particularly long, so some of these changes may not be showing their full benefit.

With these in place the performance is getting closer:

python Generation: 673 tokens, 114.404 tokens-per-sec, 5.873911874999999s
swift 673 tokens, 97.718867 tokens/s, 6.887104s
Task                   Python (s)   Swift (s)
Wall Time              9.6          9.4
CPU Time               10.94        7.6
Metal Submit           1.75         1.21
Metal didComplete      0.48         0.28
mlx::core::scheduler   2.92         1.99
blas_thread_server     2.9          -
main                   1.67         0.69

We have a big win in CPU time and even in total wall time (startup time, I presume), but we are still nearly a second slower on the generation, so more work is needed.

Part 2!

Here is the output from swift -- for comparison I had 114 tokens/s from python (though it was usually lower than that and I found the number to be noisy).

Prompt:     21 tokens, 212.221945 tokens/s
Generation: 673 tokens, 114.662957 tokens/s, 5.869376s

Here is output from Instruments showing the synchronous evaluation (via item()), the async_eval() calls and the GPU activity in the python program:

[Instruments screenshot: python generation showing item() evaluation, async_eval() calls, and GPU activity]

Notice that the GPU stays active most of the time.

Here is a similar view of the swift program (at around 97 tokens/s) from part 1:

[Instruments screenshot: swift generation from part 1, around 97 tokens/s]

The pattern is similar, but there are ~1.5ms gaps where the GPU goes idle -- the scheduler didn't have any work to do.

The final version looks like this:

[Instruments screenshot: swift generation, final version]

The GPU is kept busy and we are hitting our performance targets!

The final changes resolved an issue where the interleaved asyncEval() and item()/eval() calls caused item() to wait for the asyncEval() to complete -- see ml-explore/mlx-swift#127.

@davidkoski requested a review from awni on August 23, 2024 at 23:13
@davidkoski marked this pull request as draft on August 23, 2024 at 23:14
if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
(keys, values) = cache.update(keys: keys, values: values)
Collaborator Author

There are a lot of very mechanical changes like this to support the new KVCache. The KVCache is a reference type, so we pass it in and mutate it in place rather than passing the state in and out.


public let vocabularySize: Int
public let kvHeads: [Int]
public let headDim: IntOrPair
Collaborator Author

Support for creating the KVCache
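
For illustration, these properties are roughly enough to build one cache per layer; the factory function and the KVCache initializer here are assumptions for this sketch, not necessarily the PR's exact API:

    import MLX

    // Hypothetical sketch: create one KVCache per layer from the model's
    // configuration. The KVCache(headDim:kvHeads:) initializer is assumed.
    func newKVCaches(kvHeads: [Int], headDim: IntOrPair) -> [KVCache] {
        kvHeads.map { heads in
            KVCache(headDim: headDim, kvHeads: heads)
        }
    }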

  out = matmul(out, model.embedTokens.weight.T)
  out = out * self.logitScale
- return (out, cache)
+ return out
Collaborator Author

Every model is changed in this way. The changes aren't all identical, but about 70% are; the rest differ only because of specifics in the attention layers.

@@ -5,56 +5,12 @@ import MLX
import MLXRandom
import Tokenizers

private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
Collaborator Author

These functions move inside the SampleContext for encapsulation. They also get compiled to match the python side.
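
A minimal sketch of the sampler encapsulation (member names here are illustrative rather than the PR's exact API; the PR additionally wraps the sampling functions in MLX compilation to match python):

    import MLX
    import MLXRandom

    // Illustrative sketch only: greedy sampling at temperature 0, otherwise
    // categorical sampling over the temperature-scaled logits.
    struct SampleContext {
        let temperature: Float

        func sample(logits: MLXArray) -> MLXArray {
            if temperature == 0 {
                return argMax(logits, axis: -1)
            } else {
                return MLXRandom.categorical(logits * (1 / temperature))
            }
        }
    }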

}

/// Encapsulation of the repetitionPenalty
struct RepetitionContext: Sendable {
Collaborator Author

Refactored out the repetition context code for encapsulation -- I found that this was missing the token append line.
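
For orientation, the bookkeeping side of that context might look roughly like this (only the rolling token window and the append that was missing; the MLX penalty math is omitted and member names are illustrative):

    // Illustrative sketch: keep a rolling window of recent tokens that the
    // repetition penalty is applied over. The append is the "token append
    // line" mentioned above.
    struct RepetitionContext: Sendable {
        let repetitionContextSize: Int
        private(set) var tokens: [Int] = []

        mutating func append(_ token: Int) {
            tokens.append(token)
            if tokens.count > repetitionContextSize {
                tokens.removeFirst()
            }
        }
    }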


// compute the next state and async eval the next token
y = step(previous: previousY)
asyncEval(y)
Collaborator Author

This restructuring supports async eval now. We generate the graph for the next token and start the eval while returning the previously generated token.

Collaborator Author

For comparison here is the python version:

    while True:
        next_y, next_logprobs = _step(y)
        mx.async_eval(next_y)
        yield y.item(), logprobs
        y, logprobs = next_y, next_logprobs

Ours is iterator style, so we aren't yielding; I restructured it a bit to keep it clear.

Not for this PR, but now that we can easily (memory-wise) evaluate off the main thread we could change this to an AsyncStream, which does have yield. IIRC there were some issues around cleanly shutting down, but I think those can be resolved.
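
For reference, a rough sketch of what that AsyncStream variant could look like (not part of this PR; step is passed in as a stand-in for the generate internals, and stopping on an end-of-sequence token is omitted):

    import MLX

    // Sketch only: wrap the generate loop in an AsyncStream so callers can use
    // `for await token in ...`. Termination (EOS, max tokens) is omitted.
    func generateStream(prompt: MLXArray, step: @escaping (MLXArray) -> MLXArray) -> AsyncStream<Int> {
        AsyncStream { continuation in
            let task = Task {
                var y = step(prompt)
                asyncEval(y)
                while !Task.isCancelled {
                    // build the graph for the next token and start evaluating it
                    let next = step(y)
                    asyncEval(next)
                    // item() will wait for y's evaluation if it hasn't finished
                    continuation.yield(y.item(Int.self))
                    y = next
                }
                continuation.finish()
            }
            continuation.onTermination = { @Sendable _ in task.cancel() }
        }
    }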

- Prompt Tokens per second: \(promptTokensPerSecond.formatted())
- Generation tokens per second: \(tokensPerSecond.formatted())
+ Prompt:     \(promptTokens.count) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+ Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
Collaborator Author

Updated the summary stats to match the current python code

import Foundation
import MLX

public class KVCache: Evaluatable {
Collaborator Author

The KVCache. I will add the rotating cache as well, but the default max_kv_size is None so we only need this for the performance tuning.
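
For orientation, a stripped-down version of the idea (the real class preallocates in steps and conforms to Evaluatable; the sequence axis is assumed here to be axis 2, i.e. [batch, heads, sequence, headDim]):

    import MLX

    // Simplified sketch, not the PR's exact implementation: concatenate new
    // keys/values onto the cached ones and track the offset used for RoPE.
    public class SimpleKVCache {
        public private(set) var offset = 0
        private var keys: MLXArray?
        private var values: MLXArray?

        public func update(keys newKeys: MLXArray, values newValues: MLXArray) -> (MLXArray, MLXArray) {
            if let keys, let values {
                self.keys = concatenated([keys, newKeys], axis: 2)
                self.values = concatenated([values, newValues], axis: 2)
            } else {
                self.keys = newKeys
                self.values = newValues
            }
            offset += newKeys.dim(2)
            return (self.keys!, self.values!)
        }
    }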


}

public struct NaiveStreamingDetokenizer: StreamingDetokenizer {
Collaborator Author

The new NaiveStreamingDetokenizer

if let new = detokenizer.next() {
print(new, terminator: "")
fflush(stdout)
}
Collaborator Author

Using the new detokenizer

@awni (Member) commented Aug 24, 2024

With these in place the performance is getting closer:

python Generation: 673 tokens, 114.404 tokens-per-sec, 5.873911874999999s
swift 673 tokens, 97.718867 tokens/s, 6.887104s

That is very strange... all of the fixes you pulled in should have gotten us to basically parity with Python, but it's still substantially slower 🤔

Could be some perf improvements in the updated MLX (17.0) as well, not sure what you were running.

@awni (Member) commented Aug 24, 2024

I am not sure what the blas_thread_server is, but it is soaking up a bunch of CPU time

That's not an MLX thing... it looks like it's coming from OpenBLAS, which is probably from NumPy if I had to guess. Very odd though... it's possibly dependent on the NumPy version / backend.

@awni (Member) commented Aug 24, 2024

@davidkoski any reason this is still in draft? It looks more or less ready for review/merge?

@davidkoski (Collaborator, Author) commented

@davidkoski any reason this is still in draft? It looks more or less ready for review/merge?

Still needs:

  • some documentation
  • another KVCache implementation or at least setting it up as a protocol so it can be used that way

I wanted to figure out the rest of the performance gap but I think this is useful enough that we could take it.

Could be some perf improvements in the updated MLX (17.0) as well, not sure what you were running.

I was running 16.0 for both.

One thing I will check is the implementation of the LLM model code -- maybe that has diverged.

@davidkoski marked this pull request as ready for review on August 27, 2024 at 22:54
@davidkoski (Collaborator, Author) commented

@awni this is ready for review. It requires the changes from ml-explore/mlx-swift#127 so I will have one more commit after I merge and tag that one.

@awni (Member) commented Aug 27, 2024

Awesome!! Will do so asap.

@@ -78,54 +34,188 @@ public struct GenerateParameters: Sendable {
}
}

struct SampleContext {
Collaborator Author

Encapsulation of the sampling

@@ -195,20 +285,17 @@ public func generate(
{
// compute the timing for the prompt
if tokens.isEmpty {
eval(token)
Collaborator Author

Switched to returning Int -- returning the MLXArray is sketchy at best since it has not been evaluated, and you have to be very careful with threads. Anyway, this is internal to the generate() function.
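
Roughly the idea, as an illustrative fragment rather than the exact code in generate():

    import MLX

    // Illustrative: evaluate and hand back a plain Int token id instead of a
    // lazy MLXArray that the caller would have to evaluate on the right thread.
    func tokenId(from token: MLXArray) -> Int {
        eval(token)
        return token.item(Int.self)
    }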

- sampling code compiled
- KVCache
- async_eval
- NaiveStreamingDetokenizer
- use mlx-swift 16.0.1
let newSegment = tokenizer.decode(tokens: segmentTokens)
let new = newSegment.suffix(newSegment.count - segment.count)

if new.contains("\n") {
Member

Shouldn't that be ends with?

Collaborator Author

No, and perhaps that is a bug with the python detokenizer. Adding a single token can go from foo to foo\n\nbar.

Member

Hmm... but in that case it might not be safe to split the stream after \n\nbar? The implicit assumption is that whenever the text ends in \n the detokenization is always the same whether it is done on the full string or split into two chunks at \n.

Collaborator Author

Let me get an example -- maybe we need to discuss this.

Collaborator Author

For the first paragraph in the example it produces this:

tokens: [1954, 22094, 29892, 9425, 1006, 2029, 3406, 29892, 393, 12595, 14973, 1063, 8608, 287, 1549, 278, 286, 2879, 310, 931, 304, 385, 21502, 305, 988, 521, 2561, 719, 20913, 29879, 322, 3514, 332, 1794, 696, 314, 278, 12625, 310, 1238, 566, 284, 5546, 29889, 2567, 29892, 7623, 393, 266, 457, 10435, 322, 5075, 272, 505, 1063, 8611, 411, 263, 16624, 936, 25792, 29892, 20794, 14904, 304, 263, 1855, 29885, 310, 1095, 2222, 24496, 322, 7134, 29889, 512, 445, 1855, 29885, 29892, 2012, 310, 8121, 303, 292, 363, 10657, 322, 20123, 29892, 13874, 21126, 338, 304, 16508, 29714, 322, 3957, 4822, 21188, 12625, 29889, 29871, 13, 13]

string: "Imagine, dear interlocutor, that thou hast been transported through the mists of time to an epoch where chivalry reigns and samurai roam the lands of feudal Japan. Now, picture that thine horse and armor have been replaced with a mystical portal, bringing thee to a realm of endless possibilities and knowledge. In this realm, instead of jousting for honor and territory, thy quest is to seek wisdom and connection across distant lands. "

The next token is \n\nThis:

tokens: [1954, 22094, 29892, 9425, 1006, 2029, 3406, 29892, 393, 12595, 14973, 1063, 8608, 287, 1549, 278, 286, 2879, 310, 931, 304, 385, 21502, 305, 988, 521, 2561, 719, 20913, 29879, 322, 3514, 332, 1794, 696, 314, 278, 12625, 310, 1238, 566, 284, 5546, 29889, 2567, 29892, 7623, 393, 266, 457, 10435, 322, 5075, 272, 505, 1063, 8611, 411, 263, 16624, 936, 25792, 29892, 20794, 14904, 304, 263, 1855, 29885, 310, 1095, 2222, 24496, 322, 7134, 29889, 512, 445, 1855, 29885, 29892, 2012, 310, 8121, 303, 292, 363, 10657, 322, 20123, 29892, 13874, 21126, 338, 304, 16508, 29714, 322, 3957, 4822, 21188, 12625, 29889, 29871, 13, 13, 4013]

string: "Imagine, dear interlocutor, that thou hast been transported through the mists of time to an epoch where chivalry reigns and samurai roam the lands of feudal Japan. Now, picture that thine horse and armor have been replaced with a mystical portal, bringing thee to a realm of endless possibilities and knowledge. In this realm, instead of jousting for honor and territory, thy quest is to seek wisdom and connection across distant lands.

This"

Collaborator Author

The next paragraph starts like this:

tokens: [4013, 1855]
string: "This real"

tokens: [4013, 1855, 29885]
string: "This realm"

So it is the sequence of tokens that produces the \n\n, not token 4013 by itself.

Collaborator Author

Offline discussion:

  • python is seeing newlines at the end of the string (I thought it wasn't but I can't be sure what I was looking at now)
  • swift (different tokenizer implementation) sees newlines at the end with some models and not others
  • we will use hasSuffix() to match the python behavior (roughly as sketched below)
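
A rough sketch of the resulting segmenting logic, with the tokenizer reduced to a decode closure so the example stands alone; property and method names here are illustrative, not the PR's exact internals:

    // Illustrative sketch of the hasSuffix()-based segmenting.
    struct NaiveStreamingDetokenizerSketch {
        let decode: ([Int]) -> String
        private var segment = ""
        private var segmentTokens = [Int]()

        mutating func append(token: Int) {
            segmentTokens.append(token)
        }

        mutating func next() -> String? {
            let newSegment = decode(segmentTokens)
            let new = newSegment.suffix(newSegment.count - segment.count)

            if newSegment.hasSuffix("\n") {
                // the decoded text ends in a newline, so it is safe to start a
                // fresh segment here (matching the python behavior)
                segment = ""
                segmentTokens = []
            } else {
                segment = newSegment
            }

            return new.isEmpty ? nil : String(new)
        }
    }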

@awni (Member) left a comment

This is awesome!! 🚀

Just one comment, lmk what you think. O/w good to go!

@davidkoski merged commit ab94ffc into main on Aug 29, 2024 (1 check passed)
@davidkoski deleted the kvcache branch on August 29, 2024 at 20:04