Fix: StaticCache & inputs_embeds #32932

Merged

Conversation

zucchini-nlp
Member

@zucchini-nlp zucchini-nlp commented Aug 22, 2024

What does this PR do?

Fixes #32911. Enables generation with StaticCache and inputs_embeds; previously it failed due to an incorrect calculation of max_cache_length.
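For context, a minimal sketch of the pattern this PR unblocks. The checkpoint name is a placeholder, not from the PR; any model whose architecture supports StaticCache should behave the same way:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; substitute any StaticCache-capable model.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

inputs = tokenizer("Hello world", return_tensors="pt")
# Build the embeddings manually and pass them instead of input_ids.
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# Before this PR, this combination raised an error because
# max_cache_length was computed incorrectly when only inputs_embeds
# (and no input_ids) were provided.
output_ids = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    max_new_tokens=20,
    cache_implementation="static",
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```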

Added a test for this, plus generation tests for Gemma2ForCausalLM. Some things to note:

  • Gemma2 doesn't support StaticCache. It could with some small changes, but imo we shouldn't.
  • Static-shape cache classes have no support for contrastive search, DoLa, low-memory generation, or assisted decoding, so these tests are all skipped for Gemma2. If we want to enable them, I think it should go in another PR that upgrades the static cache classes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks, two nits but good otherwise. Do we take the max of num_beams and num_return_sequences because the returned sequences stem from beams?
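(For readers outside the review: a hedged sketch of the sizing question being discussed, as an illustration of the reasoning rather than the PR's exact code. Attribute names follow transformers.GenerationConfig; the numbers are made up.)

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(num_beams=4, num_return_sequences=2)
batch_size = 1

# The static cache's batch dimension must cover the widest batch expansion
# generate performs: beam search expands the batch by num_beams, while
# multi-sequence sampling expands it by num_return_sequences. Taking the
# max covers both decoding paths with a single allocation.
max_batch_size = batch_size * max(
    generation_config.num_beams,
    generation_config.num_return_sequences,
)
print(max_batch_size)  # 4
```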

src/transformers/generation/utils.py (outdated review threads, resolved)
Member

@gante gante left a comment


Thank you for taking care of gemma 2 🤗

src/transformers/generation/utils.py (outdated review threads, resolved)
tests/generation/test_utils.py (outdated review thread, resolved)
@@ -59,7 +60,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    all_generative_model_classes = ()
+    all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
Member


😱 good spot!

Collaborator


This was removed because it was failing too many tests.

Member Author


Yes, I skipped the ones that shouldn't be triggered due to the model-specific cache and fixed the other failing ones.

def test_generate_from_inputs_embeds_with_static_cache(self):
    # Skipped: Gemma2 relies on its model-specific HybridCache, so the
    # StaticCache path is not exercised for this model.
    pass

def _check_attentions_for_generate(
Member


Let's add the reason for the overwrite at the top of the fn as a comment, here and on the other functions that need an overwrite! That way, we immediately know why the function needs to exist :)

(I see that you added a few comments below, like "HybridCache has fixed length for key/values"; moving them to the top suffices)
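(A sketch of the suggested convention, with an elided body; the reason string is quoted from the comment above, and the class name is only illustrative.)

```python
class Gemma2ModelTest:
    # Illustrative stub, not the PR's code: the overwrite reason sits on
    # the first line of the overridden helper.
    def _check_attentions_for_generate(self, *args, **kwargs):
        # Overwrite reason: HybridCache has fixed length for key/values,
        # so the generic attention-shape checks don't apply.
        ...
```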

@zucchini-nlp zucchini-nlp force-pushed the embeds-with-static-cache branch from 4dd1494 to fce9e7e on August 30, 2024
Member

@gante gante left a comment


Thank you for iterating 💛

@zzxslp

zzxslp commented Sep 6, 2024

Hi, I ran into similar errors as in #32911; will this PR get merged?

@zucchini-nlp
Member Author

Yes, merging now; it should be ready.

@zucchini-nlp zucchini-nlp merged commit 1759bb9 into huggingface:main Sep 6, 2024
23 checks passed
zucchini-nlp added a commit to zucchini-nlp/transformers that referenced this pull request Sep 6, 2024
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024

Successfully merging this pull request may close these issues.

Error in _prepare_generated_length