T5Model in fp16 still yield nan with more complex examples #4586

Closed
rpowalski opened this issue May 25, 2020 · 21 comments

@rpowalski

🐛 Bug

Hello, and thank you for the recent PR with the fp16 fixes. It seems to work well with short inputs, but once the model is fed more complex data it still yields nans.

Information

Model I am using: T5

Language I am using the model on: English

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Run the code:

from transformers import T5Model
import torch

model = T5Model.from_pretrained("t5-base").cuda().half().eval()
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2]

output:

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       dtype=torch.float16, grad_fn=<SliceBackward>)

Expected behavior

Output with non-nan values.

Environment info

  • transformers version: 2.10.0
  • Platform: Linux-4.15.0-88-generic-x86_64-with-debian-buster-sid
  • Python version: 3.6.10
  • PyTorch version (GPU?): 1.4.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no
@patrickvonplaten patrickvonplaten self-assigned this May 27, 2020
@calclavia

calclavia commented May 29, 2020

I got the same issue - it seems to happen with the larger models (t5-small is fine).

@patrickvonplaten

I can reproduce the error - will investigate :-)

@patrickvonplaten

patrickvonplaten commented Jun 5, 2020

Okay, this took me quite some time to figure out...

So here is what happens when all modules are set to half precision, as in the code snippet above. At some point, in the line:

layer_output = hidden_states + self.dropout(y)

the tensor layer_output contains inf values and then later in:
x = x / torch.sqrt(variance + self.variance_epsilon)

nan values enter the game...

I don't really think this is a bug in T5; it's due to T5's rather numerically unstable architecture. model.half() essentially corresponds to apex opt level O3 (https://nvidia.github.io/apex/amp.html#o3-fp16-training), which in itself tends to become unstable...
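
A minimal toy illustration of that mechanism (made-up values; on older PyTorch versions these float16 ops may need to run on CUDA):

import torch

# fp16 saturates around 65504, so a residual addition can overflow to inf
hidden_states = torch.tensor([[60000.0, 10000.0]], dtype=torch.float16)
layer_output = hidden_states + hidden_states           # -> [[inf, 20000.]]

# T5's layer norm then divides by the square root of the (now infinite) variance
variance = layer_output.pow(2).mean(-1, keepdim=True)
print(layer_output / torch.sqrt(variance + 1e-6))      # -> [[nan, 0.]]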

Using your code above with the apex package instead of calling half() on the model, you can observe the following. This snippet, which is essentially the same as yours:

from transformers import T5Model
from apex import amp
import torch

model = T5Model.from_pretrained("t5-base").cuda().eval()
model = amp.initialize(model, opt_level="O3") 

inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2]  # nan output

yields the same output consisting of nan values. The same happens for opt_level O2.
Using the recommended O1 level of optimization:

from transformers import T5Model
from apex import amp
import torch

model = T5Model.from_pretrained("t5-base").cuda().eval()
model = amp.initialize(model, opt_level="O1") 

inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2]  # valid output

however, does not produce any nan values. As far as I know, O1 is also the recommended setting: https://nvidia.github.io/apex/amp.html#o1-mixed-precision-recommended-for-typical-use.
O1 can already greatly speed up your calculations and save quite some memory, so I would recommend going with it.
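
For training, a minimal O1 sketch would look like this (apex and a GPU assumed; the loss below is a dummy stand-in, not a real fine-tuning objective):

import torch
from apex import amp
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").cuda().train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

inputs = torch.tensor([[37, 423, 215, 1504, 13, 8, 1]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
loss = out[0].pow(2).mean()                      # dummy loss on the decoder hidden states
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()                       # apex handles the loss scaling/unscaling
optimizer.step()
optimizer.zero_grad()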

Also pinging @mfuntowicz, @julien-c and @LysandreJik for verification

@calclavia

@patrickvonplaten Even with O1, I tried fine-tuning T5-base and in fewer than 100 iterations it converges to nan values. The stability of this model seems poor. Perhaps the first few iterations of fine-tuning require FP32.
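
One way to narrow down where it blows up (just a debugging sketch using plain PyTorch forward hooks, nothing T5-specific) is to flag the first module that produces non-finite activations:

import torch

def add_nan_hooks(model):
    # print the modules whose outputs contain inf/nan during a forward pass
    def hook(module, inputs, output):
        outputs = output if isinstance(output, (tuple, list)) else (output,)
        for t in outputs:
            if torch.is_tensor(t) and not torch.isfinite(t).all():
                print(f"non-finite output in {module.__class__.__name__}")
                break
    for module in model.modules():
        module.register_forward_hook(hook)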

@sshleifer

sshleifer commented Jun 11, 2020

I am having issues even in fp32 with everything besides t5-small.
I am having issues in O1 with t5-large and t5-base.

@stale

stale bot commented Sep 11, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Sep 11, 2020
@stale stale bot closed this as completed Sep 19, 2020
@lastmansleeping

Having the same issue with the loss going to nan when fine-tuning t5-base with fp16. t5-small works fine though.

@ghost

ghost commented Dec 15, 2020

Ran into this issue and found a workaround to get FP16 training working.
T5DenseGatedGeluDense doesn't play nice with FP16, specifically the final dense layer that resizes from d_ff back to d_model.
I used PyTorch's autocast/GradScaler mixed-precision implementation and created an exception for that specific dense layer:

import torch.nn as nn
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN  # same names as in transformers' modeling_t5.py


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states
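
To try this without editing the library in place, one could monkey-patch the class before instantiating the model (sketch only; the module path assumes a transformers 4.x layout, and google/t5-v1_1-base is just an example of a checkpoint that actually uses the gated-GELU feed-forward):

from transformers import T5Model
from transformers.models.t5 import modeling_t5

# T5LayerFF looks the class up in the module globals, so the patch must happen
# before from_pretrained() builds the model
modeling_t5.T5DenseGatedGeluDense = T5DenseGatedGeluDense
model = T5Model.from_pretrained("google/t5-v1_1-base").cuda()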

@j-min

j-min commented Dec 18, 2020

@leecming Have you also tried the fix with T5DenseReluDense?

@patrickvonplaten

patrickvonplaten commented Dec 18, 2020

Great question @j-min - I actually haven't found the time yet to test the "new" T5 models with fp16. It might very well be that the following models work fine with fp16:
https://huggingface.co/models?search=mt5
and
https://huggingface.co/models?search=t5-v1

@stale stale bot removed the wontfix label Dec 18, 2020
@j-min

j-min commented Dec 18, 2020

@patrickvonplaten @leecming I'm trying the fix as below.

import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states

Btw, it results in the error "expected scalar type Half but found Float", since hidden_states is in half precision while self.wo's parameters are in float.
Could you please advise how to bypass this error?

import torch
from torch.cuda.amp import autocast
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2]

with autocast():
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
    loss = out.last_hidden_state.exp().mean()

@j-min

j-min commented Dec 18, 2020

Oh, adding hidden_states = hidden_states.to(torch.float32) worked - never mind.
Is there a more concrete script to check whether this fixes T5's fp16 training? @patrickvonplaten

class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            hidden_states = hidden_states.to(torch.float32)
            hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            hidden_states = hidden_states.to(torch.float32)
            hidden_states = self.wo(hidden_states)
        return hidden_states


import torch
from torch.cuda.amp import autocast
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2]

with autocast():
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
    loss = out.last_hidden_state.exp().mean()

print(loss)
>>> tensor(1.1017, device='cuda:0', grad_fn=<MeanBackward0>)

@patrickvonplaten

This is actually a topic I wanted to look into more closely and didn't manage to do so time-wise...maybe next week.

But in short, one should try to train a whole T5 model with your suggested fix.

What I would recommend doing is to take your guys' fix from above and open a PR with it. Then with this PR we should fine-tune a whole t5 model on some task, e.g. using the Seq2SeqTrainer.

E.g. one could adapt this script: https://colab.research.google.com/drive/1Ekd5pUeCX7VOrMx94_czTkwNtLN32Uyu?usp=sharing and, instead of using a Bert2Bert model, just use a google/t5v1_1-small or base model and see whether there are any problems in training.
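
Roughly something like the following (a sketch only - train_dataset is a placeholder for any tokenized seq2seq dataset, and the Seq2SeqTrainer API may differ slightly across versions):

from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")

training_args = Seq2SeqTrainingArguments(
    output_dir="t5-fp16-check",
    fp16=True,                          # the setting under test
    per_device_train_batch_size=8,
    max_steps=500,
    logging_steps=10,                   # watch the logged loss for nan/inf
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,        # placeholder: any tokenized seq2seq dataset
    tokenizer=tokenizer,
)
trainer.train()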

also cc @patil-suraj in case he has better pointers/ideas

@patrickvonplaten

I'll try to do a run next week though :-)

@ghost

ghost commented Dec 18, 2020

It’s not a good fix since it relies on a specific AMP implementation (autocast) and wouldn’t work on others (e.g., Nvidia APEX). It also uses more memory than a clean AMP implementation.

A cleaner quick fix would be to copy BERT’s gradient checkpointing code and train in FP32 mode with checkpointing.
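
(For reference, a sketch of that FP32-plus-checkpointing route; for brevity it uses the gradient_checkpointing_enable() API that only landed in later transformers versions:)

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-base").cuda()
model.gradient_checkpointing_enable()   # recompute activations in backward to save memory
model.train()

input_ids = torch.tensor([[37, 423, 215, 1504, 1]]).cuda()
labels = torch.tensor([[21820, 296, 55, 1]]).cuda()

loss = model(input_ids=input_ids, labels=labels).loss   # plain FP32 forward/backward
loss.backward()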

Also, Nvidia has started supporting bf16 with the latest Ampere cards, which is good news.
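
(A sketch of what that could look like - bf16 autocast needs an Ampere-class GPU and a PyTorch version with bfloat16 autocast support:)

import torch
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[37, 423, 215, 1504, 1]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

# bf16 keeps fp32's exponent range, so residual sums that overflow in fp16 stay finite
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)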

@dorost1234

dorost1234 commented Mar 19, 2021

I am having the same issue with mt5-small producing nan with DeepSpeed. I would really appreciate any advice on this - I am having a really hard time with it, thanks a lot.
@patrickvonplaten @patil-suraj @sgugger Do you mind sharing the current state of mt5 training with fp16? Thanks a lot.

@patrickvonplaten

see: #10830

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sd3ntato

sd3ntato commented Jul 4, 2023

Anyone coming to this after some years: try https://huggingface.co/google/umt5-small instead.

@prasannakdev0

No luck with https://huggingface.co/google/umt5-small either, even though I was training in FP32.

@alisatl

alisatl commented Oct 9, 2023

I ran into this with T5-3b (https://huggingface.co/t5-3b/tree/main), using the more recent T5ForSequenceClassification head. I thought it was due to that newer head, but now I see the issue is more fundamental than that.

I'll see what my fp32 fine-tuning gives tomorrow, as I believe no comprehensive solution has been put in place just yet.
