Fix KV cache #1364
base: main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1364
The kv_cache sequence length is set to the max sequence length so that the memory is allocated once and the future states are masked out during decoding. This change would temporarily save some memory but would slow down decoding. An optimal balance would probably be a dynamic allocation scheme where the cache size starts small and is doubled whenever seq_len == cache_size, until cache_size == max_seq_len.
The cache can be preallocated, but iteratively computing attention of Q against the max_seq tokens (K), masking afterwards, and doing the matmul with the max_seq V is not efficient.
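For illustration, a minimal sketch of the doubling-allocation idea floated above (hypothetical function and shapes, not torchtune code):

```python
import torch

def grow_cache_if_needed(cache: torch.Tensor, seq_len: int, max_seq_len: int) -> torch.Tensor:
    """Double the cache along the sequence dim (assumed dim 2) once it is full,
    capped at max_seq_len. Returns either the original cache or a larger copy."""
    cache_size = cache.shape[2]
    if seq_len < cache_size or cache_size >= max_seq_len:
        return cache
    new_size = min(cache_size * 2, max_seq_len)
    b, h, _, d = cache.shape
    new_cache = torch.zeros(b, h, new_size, d, dtype=cache.dtype, device=cache.device)
    new_cache[:, :, :cache_size] = cache  # carry over the already-cached K/V entries
    return new_cache
```

Growing like this keeps the number of allocations to O(log max_seq_len) at the cost of copying the cache each time it doubles.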
@stsouko thanks for this PR! Let me repeat it back to you to see if I get this right: let's say we have max_seq_len=128k and we are generating position_id = 3, so our useful cache_size = 2. However, as it is, we would do the attention using the whole 128k. You are proposing that we should slice the cache and use just the 2 non-zero tokens. Is that right? What I don't know:
Do you think you could run a quick benchmark to check tokens per second + peak_allocated_vram? It would be a very easy approval. Thanks again for pointing it out! :)
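For reference, a simplified sketch of the slicing being discussed (hypothetical class, not the actual torchtune KVCache):

```python
import torch

class SlicedKVCache(torch.nn.Module):
    """Preallocates K/V at max_seq_len but only exposes the filled prefix."""

    def __init__(self, batch_size, num_heads, max_seq_len, head_dim, dtype=torch.float32):
        super().__init__()
        cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        self.size = 0  # number of positions cached so far

    def update(self, k_val: torch.Tensor, v_val: torch.Tensor):
        seq_len = k_val.shape[2]
        self.k_cache[:, :, self.size:self.size + seq_len] = k_val
        self.v_cache[:, :, self.size:self.size + seq_len] = v_val
        self.size += seq_len
        # Attention now sees only `size` positions instead of all max_seq_len.
        return self.k_cache[:, :, :self.size], self.v_cache[:, :, :self.size]
```

With max_seq_len=128k and only 2 cached positions, attention is then computed over 2 keys rather than 128k mostly-masked ones.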
Sure.
Before: INFO:torchtune.utils.logging:Time for inference: 40.26 sec total, 7.45 tokens/sec
After: INFO:torchtune.utils.logging:Time for inference: 10.32 sec total, 29.06 tokens/sec
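For anyone reproducing these kinds of numbers, a rough sketch of measuring tokens/sec and peak allocated VRAM (assumes a `generate_fn` callable and a CUDA device; not the torchtune recipe code):

```python
import time
import torch

def benchmark_generation(generate_fn, num_new_tokens: int):
    """Time one generation call and report tokens/sec plus peak allocated VRAM."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    generate_fn()  # assumed to produce num_new_tokens tokens
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"{elapsed:.2f} sec total, {num_new_tokens / elapsed:.2f} tokens/sec, "
          f"peak allocated VRAM: {peak_gib:.2f} GiB")
```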
Wow, these results are great @stsouko! A couple of comments:
Thank you so much for generating these numbers. I still have some concerns about whether the dynamic shapes cause issues for torch.compile that might make generation slower when compiling. I'm willing to let it land like this, though, and leave that as future work. Please let me know if you're willing to get numbers for the compiled model too; otherwise I'll land this now and open an issue to check it. If you decide to test compile, you'd have to make a small change (to fix a known bug) and remove the
I've been working on this in parallel - needed to do a warmup run for compile.
edit: compile unhappy:
torch._dynamo.exc.Unsupported: Dynamic slicing on data-dependent value is not supported
  File "/home/salman/torchtune/torchtune/modules/kv_cache.py", line 79, in update
    return k_out[:, :, :size], v_out[:, :, :size]
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
##             main    #1364      +/-   ##
==========================================
- Coverage   27.48%   27.42%   -0.07%
==========================================
  Files         272      272
  Lines       12888    12917      +29
==========================================
  Hits         3542     3542
- Misses       9346     9375      +29

☔ View full report in Codecov by Sentry.
We can probably put a flag that it's dynamic. I also wonder how unhappy compile would be if we just 2x-expand the cache as necessary using the expand_to_power_of_two fn. If you try the dynamic approach, I believe it's something like this inside of the code (see the sketch below):
Wild guess 1: maybe it would help if we defined "sliced_k_out" instead of just returning k_out[:idx]?
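A hedged sketch combining both suggestions above: name the sliced outputs and mark their sequence dim (assumed to be dim 2) as dynamic with torch._dynamo.mark_dynamic. This is an illustrative guess, not code from the PR, and whether it resolves the compile error is untested:

```python
import torch

def sliced_update_outputs(k_out: torch.Tensor, v_out: torch.Tensor, size: int):
    # Name the slices explicitly rather than returning the expressions directly.
    sliced_k_out = k_out[:, :, :size]
    sliced_v_out = v_out[:, :, :size]
    # Hint that the sequence dim varies between calls, so the compiler should
    # not specialize on one particular cached length.
    torch._dynamo.mark_dynamic(sliced_k_out, 2)
    torch._dynamo.mark_dynamic(sliced_v_out, 2)
    return sliced_k_out, sliced_v_out
```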
@@ -40,7 +40,6 @@ def __init__(
         self.register_buffer(
             "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
-        self.size = 0
If you remove this variable, you need to update MultiHeadAttention too, because it uses self.kv_cache.size.
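To illustrate the kind of call-site dependency meant here, a toy attention module that reads kv_cache.size (illustrative only, not torchtune's MultiHeadAttention):

```python
import torch
import torch.nn.functional as F

class ToyAttention(torch.nn.Module):
    """Minimal attention wrapper that consumes a KV cache like the one sketched earlier."""

    def __init__(self, kv_cache):
        super().__init__()
        self.kv_cache = kv_cache

    def forward(self, q, k, v):
        k, v = self.kv_cache.update(k, v)
        # The dependency in question: if KVCache drops its `size` attribute,
        # this access (used here for bookkeeping) breaks and must be updated too.
        cache_len = self.kv_cache.size
        assert k.shape[2] == cache_len
        return F.scaled_dot_product_attention(q, k, v)
```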
I think not very happy, particularly if we're doing that.
edit: cpu cache offloading seems to be a thing too
@SalmanMohammadi, my understanding is that we don't focus too much on generation, since we just use it for eval/testing models.
Context
The KV cache is designed for decoders and thus should contain only prior and current embeddings, not future ones.
Please link to any issues this PR addresses.
Changelog
Fix: slice the KV cache to the number of tokens cached so far during decoding, instead of attending over the full preallocated max_seq_len.
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models