Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core generation] Adds support for static KV cache #27931

Merged
merged 121 commits into from
Feb 8, 2024
Merged

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented Dec 10, 2023

~4x speedups with cuda graphs! 🥳

Currently getting ~4x speedups compare to dynamic cache with torch.compile for a single forward pass (agnostic to batch but faster for smaller batch)

Forward is very very very fast, but materializing the input costs a bit!
~10ms / forward is what we get to!

  • Refactors the way we deal with attention mask:
    • causal and padding are separated
    • does not rely on the past_key_values
    • merged in 2 line. No attention mask utils are needed, no extra complicated logic all explicit
    • LlamaAttention is not self contained, this added 20% overhead in a simple forward
    • Gets rid of the entire mask_attn_utils 😄
  • Save the cache class in the generation config
  • Init the cache with the batch size (from the generate call) and the max_length from the generation config (taking max_new_tokens) into account
  • torch.compile

Benchmark using af097af

Use it in generate:

Use this: EDIT: TO COME

Failing test left for @gante

Related to the fact that I don't return past_key_values / is None so the test_new_cache_format fails. I don't want to dive in this.

fixes #28075 , fixes #28610, fixes #28190

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker changed the title [Core genration] Adds support for static KV cache [Core generation] Adds support for static KV cache Dec 12, 2023
@oobabooga
Copy link
Contributor

If I understand correctly, this PR should close the existing gap between inference with transformers + AutoGPTQ and inference with ExLlama, as the VRAM usage would become much more controlled. I'm rooting for it :)

@ArthurZucker
Copy link
Collaborator Author

Thanks! 🤗

@patrickvonplaten
Copy link
Contributor

Exciting PR!

@xkszltl
Copy link
Contributor

xkszltl commented Feb 26, 2024

Could you help clarify the removal of comment?
image
It's probably removed by mistake as attn_weights is still not supported in FA2.
We should also add a warning before setting it to False, just like the SDPA counterpart.

@paulcx
Copy link

paulcx commented Mar 6, 2024

Hi @ArthurZucker It seems that the increase in VRAM could potentially lead to out of memory (OOM) comment1 comment2, as pointed out in this PR by @danielhanchen

It seems like a change was made in another PR which allocates a causal mask of size (16384, 16384) https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L940

The triu causes the causal mask to upcast to float32, using 16384^2 * 4bytes = 1GB of extra VRAM. We have n^2 * 4 / 1024 / 1024 = 37.25GB in your screenshot, so I'm assuming you're also doing RoPE Scaling to 100K context length? So ie a (100K, 100K) matrix was trying to be created.

Could you please take a look into it?

@gante
Copy link
Member

gante commented Mar 6, 2024

@paulcx your issue is related to this one (#29484) -- let's keep the discussion there! :)

@ArthurZucker
Copy link
Collaborator Author

Yes! And @paulcx I'm sorry this broke for you

@fxmarty
Copy link
Contributor

fxmarty commented Mar 18, 2024

See #29241 which alleviates but does not fix the issue @paulcx

@aliencaocao
Copy link
Contributor

Does this work for llava? from my testing it doesnt

@paulcx
Copy link

paulcx commented Mar 19, 2024

See #29241 which alleviates but does not fix the issue @paulcx

does this temporary fix work for 200K?

@ArthurZucker
Copy link
Collaborator Author

No, you have to update the max_position_embedding. It is allocating 200k because you set it to 200K while you machine does not support 200K input

@ArthurZucker
Copy link
Collaborator Author

ArthurZucker commented Mar 19, 2024

max_position_embedding is only use for the causal_mask now, while previously it was used for sin and cos. In both cases it was to cache the maximum number of positions that will be passed to the model. If you have 200K context length, that does not mean you can do inference / training with it!

We can also just remove it, but then you need to allocate the causal_mask at each forward pass

@paulcx
Copy link

paulcx commented Mar 20, 2024

@ArthurZucker If I understand correctly, my thought is to lower the max_position_embedding, such as 4096, because it only affects the initial position embedding and caching for 200K. But during inference, lengths approaching 200K will still be calculated, just slower. This workaround, it can ensure normal training and inference for non-200K cases. Is my understanding correct?

@ArthurZucker
Copy link
Collaborator Author

ArthurZucker commented Mar 21, 2024

This was fixed by #29753 ! Sorry @paulcx for the inconvenience. For static / compile cache you should still reduce the max position embedding or it will OOM 😉

@paulcx
Copy link

paulcx commented Mar 21, 2024

static / compile cache

Thank you and great work on new release @ArthurZucker.

Would you mind clarifying the use case of "static / compile cache" in release note? I'm not sure if I understand correctly.

@ArthurZucker
Copy link
Collaborator Author

It is mostly this: https://gist.github.com/ArthurZucker/af34221def212259b43d55a2811d2dbb, you can get x4 generation speed in transformers with torch compile and static cache!

@aliencaocao
Copy link
Contributor

It is mostly this: https://gist.github.com/ArthurZucker/af34221def212259b43d55a2811d2dbb, you can get x4 generation speed in transformers with torch compile and static cache!

Is this expected to work with llava-next?

@ArthurZucker
Copy link
Collaborator Author

I believe so yes, if not we can add support for it

@ArthurZucker
Copy link
Collaborator Author

Feel free to open an issue if it doesnt work

@aliencaocao
Copy link
Contributor

I have tried and it dont work because the vision tower changes the shape of inputs after encoding to patches. Also, it doesnt work for bnb 4 bits

@ArthurZucker
Copy link
Collaborator Author

bnb is a different issue, torch.compile might not support this (int8 yes).
For the encoder part cc @NielsRogge could be nice

@aliencaocao
Copy link
Contributor

we are using NF4 for bnb

@nxphi47
Copy link
Contributor

nxphi47 commented Mar 28, 2024

It is mostly this: https://gist.github.com/ArthurZucker/af34221def212259b43d55a2811d2dbb, you can get x4 generation speed in transformers with torch compile and static cache!

@ArthurZucker Just checking have you added this to model.generate, or we still have to follow your script there to use static KV cache?

@ArthurZucker
Copy link
Collaborator Author

@gante is working on this here #29374

itazap pushed a commit that referenced this pull request May 14, 2024
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet