Gemma 2 returns NaN when using default attn (sdpa) with padding #32390
Comments
Hi @chanind, thanks for reporting the issue! This is indeed a problem of scaled_dot_product_attention in PyTorch; a similar issue has been reported previously. Besides switching to float16:
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b", device_map="auto", torch_dtype=torch.float16
)
```
As suggested in the issue linked above, the code can also be modified as a workaround. Meanwhile, we will try to fix it on our side, thanks!
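For context, here is a minimal sketch (not taken from the thread) of the failure mode described above: when a query row is fully masked with -inf, as can happen for padding positions that have no valid token to attend to, the softmax inside scaled_dot_product_attention is undefined and the output for that row is NaN.

```python
import torch
import torch.nn.functional as F

# Tiny attention inputs: batch=1, head=1, 4 query/key positions, head_dim=8.
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Additive float mask; query row 0 is fully masked, which is the situation a
# padded position ends up in when it has nothing it is allowed to attend to.
mask = torch.zeros(1, 1, 4, 4)
mask[:, :, 0, :] = float("-inf")

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])  # NaNs: softmax over a row that is entirely -inf is undefined
print(out[0, 0, 1])  # finite values for rows that can attend to something
```

In a model forward pass such rows correspond to padding positions, which is why the hidden states for padded tokens come out as NaN.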
More than this, it's expected: the handling in transformers/src/transformers/models/llama/modeling_llama.py (lines 1063 to 1072 at commit c1aa0ed) should be propagated to Gemma2 (it was not there for some reason, my bad here).
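For illustration, a simplified sketch of the idea behind the Llama handling referenced above (not the exact transformers code): fully masked rows are allowed to attend to everything before the mask is handed to SDPA, since their outputs are discarded downstream anyway.

```python
import torch

def unmask_fully_masked_rows(causal_mask: torch.Tensor, min_dtype: float) -> torch.Tensor:
    # Rows in which every entry equals min_dtype (e.g. left-padding positions
    # that may attend to nothing) are reset to 0, so SDPA attends to all tokens
    # there instead of producing NaN; the outputs of those rows are ignored anyway.
    fully_masked = torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
    return causal_mask.masked_fill(fully_masked, 0.0)

# Example: a 4x4 additive mask whose first query row is fully masked.
min_dtype = torch.finfo(torch.float32).min
mask = torch.zeros(1, 1, 4, 4)
mask[:, :, 0, :] = min_dtype
print(unmask_fully_masked_rows(mask, min_dtype)[0, 0, 0])  # all zeros now
```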
Related to #31303
@ArthurZucker thanks for the updated info!
Same issue here when running code to hook the model's activations. Using float16 made it work.
Hey! Make sure you are using
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
Python 3.10
Transformers 4.43.3
Linux (Colab notebook)
Who can help?
@ArthurZucker
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
The default Gemma 2 2B attention implementation (sdpa) results in NaN for padding tokens. A simple demo can be seen below (also reproduced in this Colab notebook):
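A minimal sketch of such a demo (the model name and device_map follow the thread; the prompts, the left padding, and the hidden-state check are illustrative assumptions):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
tokenizer.padding_side = "left"  # assumption: left padding creates fully masked rows

# Default settings: float32 weights and the sdpa attention implementation.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map="auto")

# Two prompts of different lengths so that the batch needs padding.
inputs = tokenizer(
    ["Hello world", "A noticeably longer prompt that needs no padding at all"],
    return_tensors="pt",
    padding=True,
).to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# With sdpa, hidden states at the padded positions come out as NaN.
print(torch.isnan(outputs.hidden_states[-1]).any())
```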
This returns NaN values at the padded positions.
This can be fixed by changing the attn_implementation to anything except sdpa.
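For instance, a sketch of that workaround (attn_implementation is the standard from_pretrained argument; the choice of "eager" here is just one non-sdpa option):

```python
from transformers import AutoModelForCausalLM

# "eager" (or "flash_attention_2", if installed) bypasses the sdpa path that
# produces NaN for fully masked padding rows.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b", device_map="auto", attn_implementation="eager"
)
```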
Expected behavior
Using padding should not result in NaN for normal inputs to Gemma 2 2B.