
Gemma capping #34282

Merged 58 commits into main on Nov 19, 2024

Conversation

@ArthurZucker (Collaborator) commented on Oct 21, 2024

What does this PR do?

Adds attention logit soft-capping for Gemma2; fixes #32877.

@ArthurZucker marked this pull request as ready for review on October 21, 2024 at 15:41
@Cyrilvallez (Member) left a comment

There are a lot of edge cases in imports that are very hard to handle with the proposed approach. I think a simpler and more general approach is to do it the other way around:

  • dump all imports from modular_xxx.py as-is
  • dump all imports from the dependency files as-is (this is currently the case)
  • then, in the PostModularConverterCleaner, clean up the imports (we may even clean only the protected imports and let ruff remove the other unused, non-protected ones)

This approach is much simpler and more versatile because in the Cleaner we have access to the final source code, which is not the case when visiting the modular_xxx.py file (there we only see the modular file plus its dependencies, and it is hard to check imports against only the part of the dependency files that we copy into the final file). It would thus ensure that all needed imports are present (i.e. we would never hit a weird edge case when trying to match imports as we do currently), and we could correctly remove imports that were wrongly pulled in from the dependency files (e.g. the duplicate import in Glm caused by the Phi3 dependency).
This would also greatly reduce code complexity, in my opinion.
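A rough sketch of the "dump everything, clean afterwards" flow described above. This is a hypothetical illustration, not the actual PostModularConverterCleaner in utils/modular_model_converter.py; it assumes ruff is installed, and the helper names are made up for the example (F401 is ruff's unused-import rule).

```python
# Hypothetical sketch only: helper names and the ruff invocation are assumptions,
# not the real modular converter code.
import ast
import subprocess
from pathlib import Path


def collect_imports(source: str) -> list[str]:
    """Return every top-level import statement verbatim, in source order."""
    tree = ast.parse(source)
    lines = source.splitlines()
    imports = []
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            imports.append("\n".join(lines[node.lineno - 1 : node.end_lineno]))
    return imports


def dump_and_clean(modular_path: Path, dependency_paths: list[Path], body: str, out_path: Path) -> None:
    """Dump all imports as-is from the modular and dependency files, then prune what is unused."""
    header = []
    for path in [modular_path, *dependency_paths]:
        header.extend(collect_imports(path.read_text()))
    # dict.fromkeys deduplicates while keeping order (e.g. the duplicate import in Glm from the Phi3 dependency).
    out_path.write_text("\n".join(dict.fromkeys(header)) + "\n\n" + body)
    # Let ruff drop whatever ended up unused (F401). Protected imports living inside
    # `if is_torch_available():` blocks are not top-level and would need a dedicated pass.
    subprocess.run(["ruff", "check", "--fix", "--select", "F401", str(out_path)], check=False)
```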

[4 resolved review threads on utils/modular_model_converter.py]

- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output = flex_attention(
A contributor left a comment

Isn't it a bit misleading to use flex attention when we have attn_implementation="sdpa"? My concerns would be:

  • People who previously used SDPA (forced or not) will suddenly have different torch requirements.
  • SDPA != FlexAttention IMO; it's a different API, a different name, and potentially slightly different behaviour.
  • Are the slow tests still passing? We should ensure it still behaves roughly the same as eager.

WDYT about adding a separate attn implementation option specifically for flex attention? Not sure if this goes beyond the goal here, but control over the specific implementation is always appreciated.

Overall excited to see this, great work!
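For context on what FlexAttention enables here that plain SDPA cannot: a minimal, hedged sketch of tanh logit soft-capping expressed as a score_mod. It assumes torch >= 2.5; the shapes and the softcap value are illustrative, and this is not the exact code merged in this PR.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

softcap = 50.0  # illustrative stand-in for Gemma2's attn_logit_softcapping


def softcap_mod(score, batch, head, q_idx, kv_idx):
    # Squash each attention logit into (-softcap, softcap) before the softmax;
    # SDPA exposes no hook for this, which is why flex attention is used instead.
    return softcap * torch.tanh(score / softcap)


q = torch.randn(1, 8, 128, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Runs uncompiled for quick testing; wrap with torch.compile(flex_attention) for speed.
attn_output = flex_attention(q, k, v, score_mod=softcap_mod)
```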

@ArthurZucker (Collaborator, Author) replied

The SDPA version of Gemma never really "worked", TBH!
I'll probably add a new class for flex attention; this was simpler for testing.

@ArthurZucker (Collaborator, Author) commented

Okay @Cyrilvallez, good point regarding the cleaning! It makes more sense indeed; will update to fix 😉

@Cyrilvallez (Member) left a comment

Very nice approach! Much simpler IMO 🤗 I just added some nits for clarity.

[8 resolved review threads on utils/modular_model_converter.py]

@Cyrilvallez (Member) left a comment

LGTM, I actually love it. I think it's much better to use different attention functions instead of different attention classes: it's clearer, there is less duplicated code, and we can easily switch between implementations even after the model has been instantiated.

[3 resolved review threads on src/transformers/models/gemma2/modular_gemma2.py]
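A minimal illustration of the "attention functions instead of attention classes" pattern praised above; the registry name and the simplified implementations are assumptions made for the sketch, not transformers' actual code.

```python
from typing import Callable, Dict

import torch


def eager_attention(q, k, v, scale, softcap=None):
    # Plain matmul attention; the logits can be soft-capped before the softmax.
    scores = (q @ k.transpose(-1, -2)) * scale
    if softcap is not None:
        scores = softcap * torch.tanh(scores / softcap)
    return torch.softmax(scores, dim=-1) @ v


def sdpa_attention(q, k, v, scale, softcap=None):
    # SDPA has no hook for logit soft-capping, so `softcap` is ignored here.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=scale)


# Hypothetical registry: the implementation can be swapped by key even after
# the model has been instantiated, as the reviewer notes.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "eager": eager_attention,
    "sdpa": sdpa_attention,
}

q = k = v = torch.randn(1, 8, 16, 64)
out = ATTENTION_FUNCTIONS["eager"](q, k, v, scale=64**-0.5, softcap=50.0)
```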
@Cyrilvallez mentioned this pull request on Nov 19, 2024
@ArthurZucker merged commit 4bff54f into main on Nov 19, 2024 (27 checks passed)
@ArthurZucker deleted the gemma-capping branch on Nov 19, 2024 at 12:52
@ArthurZucker mentioned this pull request on Nov 19, 2024
@vasqu mentioned this pull request on Nov 23, 2024
BernardZach pushed a commit to BernardZach/transformers referencing this pull request on Dec 5, 2024:

* softcapping
* soft cap before the mask
* style
* ...
* super nit
* update
* fixes
* update
* small issue with modular
* fix modular imports
* update
* fixup
* simplify a hell lot
* simplify cleaning imports
* finish fixing
* update our design
* nits
* use a deprecation cycle
* updates
* Fix modular (recursive deps need to always be computed after merges!)
* push
* fix
* update
* fix modular order
* make fix-copies
* updates
* update
* ?
* don't compile for now
* ?
* fix some stuff
* donc!
* fix copies
* update
* fixup
* ?
* fix two tests
* fix?
* for now, don't use head info
* eager when output attention and sdpa or flash as it's the simplest behaviour (for our tests as well :))
* fix-copies
* revert sdpa check
* Apply suggestions from code review

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

* rebase, fix-copies and push
* add a slow integration test
* update the test
* fix left padding issue
* fix test
* remove duplicate scaling
* quality
* add a small test and make sure it works
* 2b

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

Successfully merging this pull request may close these issues.

Add logit scaling sdpa using FlexAttention for Gemma2