
Fix KV cache #1364

Open · stsouko wants to merge 3 commits into main

Conversation

stsouko (Contributor) commented Aug 19, 2024

Context

The KV cache is designed for decoders and thus should contain only prior and current token embeddings, not future ones.
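
A minimal sketch of the intended update behavior (illustrative names and shapes; not the exact diff in this PR):

import torch

def cache_update(k_cache, v_cache, size, k_val, v_val):
    # k_cache / v_cache: preallocated [b, heads, max_seq_len, head_dim] buffers
    # size: number of positions already filled; k_val / v_val: the new step(s)
    seq_len = k_val.size(2)
    k_cache[:, :, size : size + seq_len] = k_val
    v_cache[:, :, size : size + seq_len] = v_val
    size = size + seq_len
    # Return only the filled prefix (prior + current positions, no future zeros).
    return k_cache[:, :, :size], v_cache[:, :, :size], size
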

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

Fix: slice the KV cache to the positions filled so far, so attention does not run over future (still-empty) positions.

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

pytorch-bot bot commented Aug 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1364

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Aug 19, 2024

pbontrager (Contributor)

The kv_cache sequence length is set to the max sequence length so that the memory is allocated once and the future states are masked out during decoding. This change would temporarily save some memory but would slow down decoding. An optimal balance would probably be a dynamic allocation scheme where the cache size starts small and is doubled whenever seq_len == cache_size, until cache_size == max_seq_len.
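
A minimal sketch of that growth policy (a hypothetical helper, not existing torchtune code):

def next_cache_size(seq_len: int, cache_size: int, max_seq_len: int) -> int:
    # Double the cache once the current sequence fills it, capped at max_seq_len.
    if seq_len >= cache_size and cache_size < max_seq_len:
        return min(2 * cache_size, max_seq_len)
    return cache_size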

stsouko (Contributor, Author) commented Aug 19, 2024

The cache can be preallocated, but iteratively computing the attention of Q against all max_seq_len tokens of K, masking afterwards, and then doing the matmul with all max_seq_len rows of V is not efficient.
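
To make the cost concrete, a toy illustration of the per-step score matmul against the full buffer versus the filled prefix (illustrative sizes, not the benchmark configuration below):

import torch

b, heads, head_dim, max_seq_len, t = 1, 32, 128, 4096, 256

q = torch.randn(b, heads, 1, head_dim)                  # query for the current token
k_full = torch.randn(b, heads, max_seq_len, head_dim)   # full preallocated cache
k_used = k_full[:, :, :t]                               # only the filled prefix

print((q @ k_full.transpose(-2, -1)).shape)  # torch.Size([1, 32, 1, 4096])
print((q @ k_used.transpose(-2, -1)).shape)  # torch.Size([1, 32, 1, 256])

The same ratio applies to the subsequent scores @ V matmul, which is where the per-token decoding savings come from.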

felipemello1 (Contributor) commented Aug 20, 2024

@stsouko thanks for this PR! Let me repeat it back to you to see if I get this right:

Let's say we have max_seq_len=128k and we are generating position_id = 3, so our useful cache_size = 2. However, as it is, we would do the attention using the whole 128k. You are proposing that we should slice the cache and just use the 2 non-zero tokens. Is that right?

What I don't know:

  • whether the attention masking already ignores the masked positions, so this makes no difference
  • whether slicing will make torch.compile less optimal
  • whether, when the cache grows, slicing will double the amount of memory required

Do you think you could run a quick benchmark to check tokens per second + peak allocated VRAM? It would make this a very easy approval. Thanks again for pointing it out! :)
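
For reference, one way to capture both numbers around a generation call (a hypothetical wrapper assuming a CUDA device; the torchtune generate recipe already logs tokens/sec and memory, as the results below show):

import time
import torch

def measure(generate_fn, *args, **kwargs):
    """Wrap any generation call and report tokens/sec plus peak allocated VRAM."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    tokens = generate_fn(*args, **kwargs)  # e.g. the recipe's generation call
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print(f"tokens/sec: {len(tokens) / elapsed:.2f}")
    print(f"peak allocated VRAM: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    return tokens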

stsouko (Contributor, Author) commented Aug 23, 2024

Sure,
I launched generation with the default prompt "Tell me a joke?" and max new tokens = 300 on Llama3.1-8B-Instruct.
I reduced max_seq_len to 64K (the default value doesn't fit on an A100 80GB).

Before:

INFO:torchtune.utils.logging:Time for inference: 40.26 sec total, 7.45 tokens/sec
INFO:torchtune.utils.logging:Bandwidth achieved: 377.56 GB/s
INFO:torchtune.utils.logging:Memory used: 57.81 GB

After:

INFO:torchtune.utils.logging:Time for inference: 10.32 sec total, 29.06 tokens/sec
INFO:torchtune.utils.logging:Bandwidth achieved: 1472.15 GB/s
INFO:torchtune.utils.logging:Memory used: 57.81 GB

P.S. joke:

Tell me a joke? Can you make me laugh?
I'm not sure if I can make you laugh, but I'll try my best to tell you a joke. Here's one:

What do you call a fake noodle?

(wait for it...)

An impasta!

I hope that made you giggle! Do you want to hear another one?

(By the way, what kind of jokes do you like? Do you have a favorite type of humor, like puns or sarcasm?) 

(Note: If you're feeling brave, you can try telling me a joke too!) 

---

P.S. If you want to hear another joke, I can try to come up with one on the spot. Just let me know what kind of joke you're in the mood for (e.g. animal joke, pun, etc.)! 

---

Here are some more jokes to choose from:

1. Why don't eggs tell jokes? They'd crack each other up!
2. What do you call a group of cows playing instruments? A moo-sical band!
3. Why did the scarecrow win an award? Because he was outstanding in his field!

Which one do you want to hear more about? 

---

Also, if you're feeling creative, I can start a joke conversation with you! I'll start by telling a joke, and then you can tell one back, and we can keep going back and forth like that. Would you like to try that?

---

(If you're still stuck, I

felipemello1 (Contributor)

Wow, these results are great @stsouko!

A couple of comments:

  • Was your model compiled? I worry that the slicing will break the graph. However, we can probably mark it as dynamic to solve this issue.
  • I wrote this function below the other day to increase the cache in powers of two. With it, there would be no need to start with a huge cache if it's not necessary. What do you think? If you like it, would you like to add it to your PR (it's ok if the answer is no :P)? Let's see what @pbontrager has to say first.
import math
import torch


def expand_to_power_of_two(tensor: torch.Tensor, new_max: int, dim: int, fill: int = 0) -> torch.Tensor:
    """
    Expand a tensor along a specified dimension to the next power of 2 that is at least `new_max`.
    Args:
        tensor (Tensor): The tensor to expand.
        new_max (int): The new minimum size along the dimension `dim`.
        dim (int): The dimension along which to expand the tensor.
        fill (int): The value used to fill the new positions. Default is 0.
    Returns:
        Tensor: The expanded tensor (returned unchanged if it is already large enough).
    Example:
        >>> tensor = torch.ones(3, 5)
        >>> expanded_tensor = expand_to_power_of_two(tensor, new_max=10, dim=1)
        >>> expanded_tensor.shape
        torch.Size([3, 16])

        >>> tensor = torch.ones(3, 5)
        >>> expanded_tensor = expand_to_power_of_two(tensor, new_max=10, dim=0)
        >>> expanded_tensor.shape
        torch.Size([16, 5])

        >>> tensor = torch.ones(3, 5)
        >>> expanded_tensor = expand_to_power_of_two(tensor, new_max=5, dim=0)
        >>> expanded_tensor.shape
        torch.Size([8, 5])
    """
    current_size = tensor.size(dim)
    if new_max > current_size:
        # Next power of 2 that is >= new_max (ceil, not floor, so the result
        # is never smaller than the requested minimum size)
        next_power_of_two = 2 ** math.ceil(math.log2(new_max))

        # Build the padding block with the extra size needed along `dim`
        additional_size = next_power_of_two - current_size
        new_shape = list(tensor.shape)
        new_shape[dim] = additional_size
        pad = torch.full(new_shape, fill, dtype=tensor.dtype, device=tensor.device)

        # Concatenate the original tensor with the padding along the specified dimension
        return torch.cat([tensor, pad], dim=dim)
    return tensor


tensor = torch.ones(8, 5)
expanded_tensor = expand_to_power_of_two(tensor, new_max=10, dim=0)
print(expanded_tensor.shape)  # torch.Size([16, 5])
print(expanded_tensor)

pbontrager (Contributor)

Thank you so much for generating these numbers. I still have some concerns about whether dynamic shapes will cause issues for torch.compile and make generation slower when compiled. I'm willing to let it land like this, though, and leave that as future work. Please let me know if you're willing to get numbers for the compiled model too; otherwise I'll land this now and open an issue to check it.

If you decide to test compile, you'd have to make a small change (to fix a known bug) and remove the .item() from line 67 in kv_cache.py. Thanks for your work, and please let me know how you want to proceed.

SalmanMohammadi (Collaborator) commented Aug 23, 2024

Thank you so much for generating these numbers. I still have some concerns about whether dynamic shapes will cause issues for torch.compile and make generation slower when compiled. I'm willing to let it land like this, though, and leave that as future work. Please let me know if you're willing to get numbers for the compiled model too; otherwise I'll land this now and open an issue to check it.

I've been working on this in parallel - needed to do a warmup run for compile.

edit: compile unhappy

torch._dynamo.exc.Unsupported: Dynamic slicing on data-dependent value is not supported
  File "/home/salman/torchtune/torchtune/modules/kv_cache.py", line 79, in update
    return k_out[:, :, :size], v_out[:, :, :size]

codecov-commenter

Codecov Report

Attention: Patch coverage is 0% with 3 lines in your changes missing coverage. Please review.

Project coverage is 27.42%. Comparing base (861c3b2) to head (b0454c3).
Report is 2 commits behind head on main.

Files                              Patch %   Lines
torchtune/modules/kv_cache.py      0.00%     2 Missing ⚠️
torchtune/modules/transformer.py   0.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1364      +/-   ##
==========================================
- Coverage   27.48%   27.42%   -0.07%     
==========================================
  Files         272      272              
  Lines       12888    12917      +29     
==========================================
  Hits         3542     3542              
- Misses       9346     9375      +29     

☔ View full report in Codecov by Sentry.

felipemello1 (Contributor) commented Aug 23, 2024

edit: compile unhappy

We can probably set a flag marking it as dynamic. I also wonder how unhappy compile would be if we just 2x-expanded the cache as necessary using the expand_to_power_of_two fn.

If you try marking it dynamic, I believe it's something like this inside the code:

sliced_k_out = k_out[:idx]
torch._dynamo.mark_dynamic(sliced_k_out, 2)

wild guess 1: maybe it would help if we assigned "sliced_k_out" to a variable instead of just returning k_out[:idx]?
wild guess 2: is the data-dependent error because of .item()?
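
For reference, mark_dynamic is normally applied to an input tensor before the compiled call, along these lines (a standalone toy example; whether it resolves the data-dependent slicing error above is a separate question):

import torch
import torch._dynamo

def attend(q, k):
    # Toy stand-in for attention scores over a sliced cache.
    return q @ k.transpose(-2, -1)

compiled_attend = torch.compile(attend)

q = torch.randn(1, 8, 1, 64)
for seq_len in (4, 9, 17):
    k = torch.randn(1, 8, seq_len, 64)
    # Mark the sequence dimension of k as dynamic so dynamo compiles a
    # dynamic-shape graph instead of recompiling for every cache length.
    torch._dynamo.mark_dynamic(k, 2)
    print(compiled_attend(q, k).shape)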

@@ -40,7 +40,6 @@ def __init__(
         self.register_buffer(
             "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
-        self.size = 0

If you remove this variable, you need to update MultiHeadAttention too, because it uses self.kv_cache.size

SalmanMohammadi (Collaborator) commented Aug 24, 2024

We can probably set a flag marking it as dynamic. I also wonder how unhappy compile would be if we just 2x-expanded the cache as necessary using the expand_to_power_of_two fn.

I think not very happy, particularly if we're doing fullgraph=True; static caches as we have them are compile-friendly. This seems more aligned with dynamic caching. Someone else will probably correct me here though.
Completely tangential point, but have we had thoughts about supporting different caches? e.g. dynamic as discussed here, or sink caches?

edit: CPU cache offloading seems to be a thing too

felipemello1 (Contributor)

but have we had thoughts about supporting different caches?

@SalmanMohammadi, my understanding is that we don't focus too much on generation, since we just use it for eval/testing models.
