Fix: StaticCache & inputs_embeds
#32932
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks, two nits but good otherwise. Do we take the max of num_beams, num_return_sequences because they stem from beams?
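The reviewer's question about taking the max can be illustrated with a minimal sketch (names here are illustrative, not the library's actual API): a pre-allocated cache's batch dimension has to cover the largest expansion factor used during generation, and in beam search `num_return_sequences` is at most `num_beams`, so `max()` covers both sampling and beam modes.

```python
def cache_batch_size(batch_size, num_beams=1, num_return_sequences=1):
    # Hypothetical helper: a static cache is allocated up front, so its
    # batch dimension must fit the widest expansion the generation loop
    # will use. Sampling expands by num_return_sequences; beam search
    # expands by num_beams (and num_return_sequences <= num_beams there),
    # so the max of the two is always sufficient.
    return batch_size * max(num_beams, num_return_sequences)
```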
Thank you for taking care of gemma 2 🤗
@@ -59,7 +60,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    all_generative_model_classes = ()
+    all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
😱 good spot!
This was removed because it was failing too many tests
yes, I skipped those that shouldn't be triggered due to the model-specific cache and fixed the other failing ones
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    def _check_attentions_for_generate(
Let's add the reason for the overwrite at the top of the fn as a comment, here and on the other functions that need an overwrite! That way, we immediately know why the function needs to exist :)
(I see that you added a few comments below, like "HybridCache has fixed length for key/values"; moving them to the top suffices)
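As a sketch of the convention being asked for (class and reason strings here are illustrative, not the PR's exact code): the overwrite reason sits at the top of the overridden test, and tests that cannot apply to the model-specific cache are skipped with an explicit reason rather than silently passing.

```python
import unittest

class Gemma2GenerationTests(unittest.TestCase):
    # Overwrite reason documented at the top, as the reviewer suggests:
    # Gemma2 uses a model-specific cache (HybridCache), so the generic
    # StaticCache + inputs_embeds test from the mixin does not apply.
    @unittest.skip(reason="Gemma2 uses a model-specific cache (HybridCache)")
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass
```

Skipping with `reason=` keeps the intent visible in the test report instead of an empty `pass` body that looks like a passing test.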
Force-pushed from 4dd1494 to fce9e7e
Thank you for iterating 💛
Hi, I ran into similar errors as in #32911, will this PR get merged?
Yes, merging now, should be ready
squash commit
What does this PR do?
Fixes #32911. Enables generation with StaticCache and inputs embeds; previously it was failing due to incorrect calculation of `max_cache_length`. Added a test for that and added tests for `Gemma2ForCausalLM`. Some things to note: Gemma2 cannot currently be used with `StaticCache`. It can with some small changes but imo we shouldn't
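The shape of the fix described above can be sketched as follows (a minimal illustration, assuming hypothetical helper names, not the library's actual code): when generating from `inputs_embeds` there may be no `input_ids` to measure, so the prompt length, and therefore the static cache size, must be derived from the embeddings tensor instead.

```python
import torch

def prompt_length(input_ids=None, inputs_embeds=None):
    # Hypothetical helper: input_ids has shape (batch, seq_len), while
    # inputs_embeds has shape (batch, seq_len, hidden_size). When the
    # caller passed embeddings, the prompt length must come from them,
    # otherwise the cache is sized from a missing/wrong input_ids length.
    if inputs_embeds is not None:
        return inputs_embeds.shape[1]
    return input_ids.shape[1]

def max_cache_length(max_new_tokens, input_ids=None, inputs_embeds=None):
    # A static cache is pre-allocated, so it must hold the full prompt
    # plus every token that will be generated.
    return prompt_length(input_ids, inputs_embeds) + max_new_tokens
```

Sizing the cache this way makes the `inputs_embeds` path behave the same as the `input_ids` path, which is the symptom reported in #32911.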