
Gemma 2 returns NaN when using default attn (sdpa) with padding #32390

Closed · 2 of 4 tasks
chanind opened this issue Aug 2, 2024 · 10 comments

@chanind (Contributor) commented Aug 2, 2024

System Info

Python 3.10
Transformers 4.43.3
Linux (Colab notebook)

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The default Gemma 2 2B attention implementation (sdpa) results in NaN for padding tokens. A simple demo can be seen below (also reproduced in this Colab notebook):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

inputs = tokenizer(["Hello I am a couch", "cats"], return_tensors="pt", padding=True).to('cuda')
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

print(outputs.logits)

This returns the following

tensor([[[-24.3121,  -8.7513,  -6.9736,  ..., -18.3960, -17.4268, -24.3171],
         [-16.8873,  -4.7767,   5.8828,  ...,  -9.4981,  -9.3307, -16.7723],
         [-18.3313,   1.3191,  -4.6598,  ...,  -2.4244,   1.6774, -18.2153],
         [-18.9110,  -5.8708, -11.7827,  ...,  -5.6606,  -4.2607, -18.8535],
         [-20.1359,  -8.4194, -15.1834,  ..., -13.0231, -11.8288, -19.9716],
         [-16.8807,   5.8885,   0.1881,  ...,  -3.7045,  -6.0659, -16.8421]],
        [[     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan]]],
       device='cuda:0')

This can be fixed by changing the attn_implementation to anything except sdpa.
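For reference, a minimal sketch of loading the model with a non-sdpa implementation (eager here; flash_attention_2 should also work if the flash-attn package is installed):

from transformers import AutoModelForCausalLM

# Load Gemma 2 2B with the eager attention implementation instead of sdpa.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map="auto",
    attn_implementation="eager",  # or "flash_attention_2"
)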

Expected behavior

Using padding should not result in NaN logits for normal inputs to Gemma 2 2B.

chanind added the bug label Aug 2, 2024
@qubvel (Member) commented Aug 2, 2024

Hi @chanind, thanks for reporting the issue!

This is indeed a problem with scaled_dot_product_attention in PyTorch.

The cause of the NaN is how softmax is computed over fully-masked rows in the attention mask. I hope it will be fixed in future versions of PyTorch; here is a related PR.

Also, a similar issue has been reported previously.

Besides switching to eager/flash_attention_2, you could also try:

1. Use float16 dtype:

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b", device_map="auto", torch_dtype=torch.float16
)

2. Modify the attn_mask min value.

As suggested in the issue above, we can modify the attn_mask to use another min value instead of torch.finfo(dtype).min, for example torch.finfo(dtype).min / 2. To apply this, find min_dtype = torch.finfo(dtype).min in the Gemma modeling file and replace it with torch.finfo(dtype).min / 2 (a minimal sketch of the failure mode and this workaround follows below).
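For illustration, here is a minimal standalone sketch (plain PyTorch, not the transformers code path; the exact behavior can vary by sdpa backend and PyTorch version) of why a fully masked row produces NaN and why a large-but-finite bias such as torch.finfo(dtype).min / 2 avoids it:

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)

# Boolean mask: True = attend, False = masked. Row 0 is fully masked,
# which is what a padding-only query position looks like.
mask = torch.tensor([[[[False, False],
                       [True,  True]]]])

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])  # typically NaN: the masked row's scores are all -inf, and softmax over them is 0/0
print(out[0, 0, 1])  # finite values for the normal row

# With a large-but-finite additive bias the scores stay finite, so softmax
# degrades to a uniform distribution over the masked row instead of NaN.
dtype = q.dtype
bias = torch.zeros(1, 1, 2, 2, dtype=dtype)
bias[0, 0, 0, :] = torch.finfo(dtype).min / 2
out_finite = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out_finite[0, 0, 0])  # finite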

Meanwhile, we will try to fix it on our side, thanks!

ArthurZucker added a commit that referenced this issue Aug 3, 2024
@ArthurZucker (Collaborator) commented Aug 3, 2024

More than this, it's expected, as the sdpa path does not support logit soft-capping (for Gemma 2).
We do already take the sdpa bug into account when creating the mask, @qubvel, see here:

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
    if AttentionMaskConverter._ignore_causal_mask_sdpa(
        attention_mask,
        inputs_embeds=input_tensor,
        past_key_values_length=past_seen_tokens,
        is_training=self.training,
    ):
        return None

This should be propagated to Gemma2 (it was not there for some reason, my bad).

@ArthurZucker (Collaborator) commented:
Related to #31303

@qubvel (Member) commented Aug 3, 2024

@ArthurZucker thanks for the updated info!

@yaolu-zjut commented:
Hi, I have met a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is nan:
[screenshot of training logs]
So what's wrong? I use the same code to fine-tune llama3-8b and it works well.
These are my settings:
[screenshot of training settings]

@EMZEDI commented Aug 15, 2024

Same issue here when running code that hooks the model's activations. Using float16 made it work.

@ArthurZucker (Collaborator) commented:
Hey! Make sure you are using eager or flash_attention_2, not sdpa!
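(A quick sanity check, sketched under the assumption that model and the padded inputs from the reproduction above are still in scope:)

import torch

# Confirm which attention implementation the loaded model is actually using.
print(model.config._attn_implementation)  # should print "eager" or "flash_attention_2"

# Verify that the padded batch now yields finite logits.
with torch.no_grad():
    outputs = model(**inputs)
assert not torch.isnan(outputs.logits).any()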

@Shengyun-Si commented:
> Hi, I have met a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is nan [screenshot]. So what's wrong? I use the same code to fine-tune llama3-8b and it works well. These are my settings: [screenshot]

Hi, I have the same issue. How did you solve it? 😊

@yaolu-zjut commented:
> Hi, I have met a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is nan [screenshot]. So what's wrong? I use the same code to fine-tune llama3-8b and it works well. These are my settings: [screenshot]
>
> Hi, I have the same issue. How did you solve it? 😊

Hi, I just use eager instead of sdpa, like this:

model = AutoModelForCausalLM.from_pretrained(
    args.prune_model_path,
    trust_remote_code=True,
    device_map=device_map,
    attn_implementation="eager",
)


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
