T5Model in fp16 still yields nan with more complex examples #4586
Comments
I got the same issue - seems to happen with the larger models (t5 small is fine) |
I can reproduce the error - will investigate :-) |
Okay, this took me quite some time to figure out... When setting all modules to half precision, as is done in the code snippet above, the following happens: at some point in `src/transformers/modeling_t5.py`, line 188 (at commit acaa2e6), the tensor `layer_output` contains `inf` values, and later, at line 156 of the same file, `nan` values enter the game...

I don't really think this is a bug in T5; it's just due to T5's rather unstable architecture. Running your code above with `opt_level="O3"`:

```python
from transformers import T5Model
from apex import amp
import torch

model = T5Model.from_pretrained("t5-base").cuda().eval()
model = amp.initialize(model, opt_level="O3")

inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:, :2]  # nan output
```

yields the same output consisting of `nan` values, whereas the same snippet with `opt_level="O1"`:

```python
from transformers import T5Model
from apex import amp
import torch

model = T5Model.from_pretrained("t5-base").cuda().eval()
model = amp.initialize(model, opt_level="O1")

inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:, :2]  # valid output
```

does not produce any `nan` values.

Also pinging @mfuntowicz, @julien-c and @LysandreJik for verification. |
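For intuition on how an `inf` turns into `nan` downstream (an illustrative sketch, not code from this thread): fp16 can only represent magnitudes up to about 65504, so a large activation overflows to `inf`, and the next operation that combines that `inf` with itself or with zero yields `nan`.

```python
import torch

x = torch.tensor([70000.0]).half()  # above fp16's maximum of ~65504
print(x)        # tensor([inf], dtype=torch.float16)
print(x - x)    # tensor([nan], dtype=torch.float16)
print(x * 0.0)  # tensor([nan], dtype=torch.float16)
```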
@patrickvonplaten Even with O1, I tried fine-tuning T5-base and it quickly converges to nan values in fewer than 100 iterations. The stability of this model seems poor; perhaps the first few iterations of fine-tuning need to run in FP32. |
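One way to act on that last idea (a generic sketch, not code from this thread; the warmup step count and the assumption that each batch contains `labels` are mine): keep the first optimizer steps in fp32 and only enable autocast afterwards.

```python
from torch.cuda.amp import GradScaler, autocast


def train_with_fp32_warmup(model, optimizer, train_loader, fp32_warmup_steps=100):
    """Run the first `fp32_warmup_steps` optimizer steps in fp32, then switch to fp16 AMP.

    Assumes each batch is a dict of tensors including `labels`, so the model returns `.loss`.
    """
    scaler = GradScaler()
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()
        use_amp = step >= fp32_warmup_steps  # fp32 early on, fp16 once training has settled
        with autocast(enabled=use_amp):
            loss = model(**batch).loss
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
```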
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
Having the same issue with the loss going to `nan`. |
Ran into this issue and found a workaround to get FP16 training working. |
@leecming Have you also tried the fix with |
Great question @j-min - I actually haven't found the time yet to test the "new" T5 models with fp16. It might very well be that the following models work fine with fp16: |
@patrickvonplaten @leecming I'm trying the fix as below:

```python
# drop-in replacements for the corresponding feed-forward modules in modeling_t5.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # run the output projection outside of autocast
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        # run the output projection outside of autocast
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states
```

Btw, it results in an error when running:

```python
import torch
from torch.cuda.amp import autocast
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:, :2]

with autocast():
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
    loss = out.last_hidden_state.exp().mean()
```
|
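My guess at why this errors (an assumption, not spelled out above): under native AMP the weights stay in fp32 while `autocast()` produces fp16 activations, so inside the `autocast(enabled=False)` block `self.wo` (fp32 weights) is called directly on fp16 inputs and PyTorch raises a dtype-mismatch error. A minimal sketch of that failure mode:

```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

lin = nn.Linear(4, 4).cuda()                 # weights stay in fp32 under native AMP
x = torch.randn(1, 4, device="cuda").half()  # fp16 activations, as produced inside autocast

with autocast(enabled=False):
    lin(x)  # RuntimeError: mismatched dtypes (Half input vs. Float weight)
```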
Oh, adding `hidden_states = hidden_states.to(torch.float32)` inside the `autocast(enabled=False)` block makes it work:

```python
# drop-in replacements for the corresponding feed-forward modules in modeling_t5.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            # cast the fp16 activations back to fp32 before the output projection
            hidden_states = hidden_states.to(torch.float32)
            hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            # cast the fp16 activations back to fp32 before the output projection
            hidden_states = hidden_states.to(torch.float32)
            hidden_states = self.wo(hidden_states)
        return hidden_states
```

Running the same check again:

```python
import torch
from torch.cuda.amp import autocast
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:, :2]

with autocast():
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
    loss = out.last_hidden_state.exp().mean()
print(loss)
# >>> tensor(1.1017, device='cuda:0', grad_fn=<MeanBackward0>)
```
|
This is actually a topic I wanted to look into more closely and didn't manage to, time-wise... maybe next week. In short, one should try to train a whole T5 model with your suggested fix. What I would recommend is to take your fix from above and open a PR with it. Then, with that PR, we should fine-tune a whole T5 model on some task, e.g. using the Seq2SeqTrainer. For example, one could adapt this script: https://colab.research.google.com/drive/1Ekd5pUeCX7VOrMx94_czTkwNtLN32Uyu?usp=sharing and instead of using a ...

Also cc @patil-suraj in case he has better pointers/ideas. |
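For concreteness, a rough sketch of what such a fine-tuning run could look like with the Seq2SeqTrainer (an illustration, not the proposed PR; the toy dataset and output path are placeholders):

```python
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)

model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")

# toy training examples, just to have something to feed the trainer
def make_example(src, tgt):
    enc = tokenizer(src, truncation=True)
    enc["labels"] = tokenizer(tgt, truncation=True)["input_ids"]
    return enc

train_dataset = [make_example("translate English to German: Hello world", "Hallo Welt")] * 64

args = Seq2SeqTrainingArguments(
    output_dir="t5-fp16-finetune",   # placeholder path
    fp16=True,                       # run under native AMP to exercise the patched layers
    per_device_train_batch_size=8,
    num_train_epochs=1,
    logging_steps=10,                # watch the logged loss for nan/inf
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()
```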
I'll try to do a run next week though :-) |
It’s not a good fix since it relies on a specific AMP implementation (autocast) and wouldn’t work with others (e.g., Nvidia APEX). It also uses more memory than a clean AMP implementation. A cleaner quick fix would be to copy BERT’s gradient checkpointing code and train in FP32 mode with checkpointing. Also, Nvidia has started supporting bf16 with the latest Ampere cards, which is good news. |
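As an illustration of the bf16 route mentioned above (a sketch, assuming a recent PyTorch with CUDA bf16 autocast support and an Ampere-or-newer GPU; not code from this thread): bf16 keeps fp32's exponent range, so T5's large activations do not overflow the way they do in fp16.

```python
import torch
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[21820, 296, 55]]).cuda()           # toy input ids
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

# bf16 autocast: same dynamic range as fp32, half the memory per activation
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)

print(out.last_hidden_state.isnan().any())                 # expect tensor(False)
```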
I am having the same issue with mt5-small getting nan with deepspeed. I would really appreciate any advice on this; I am having a really hard time with it. Thanks a lot. |
see: #10830 |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
Anyone coming here after some years: try https://huggingface.co/google/umt5-small instead. |
No luck with https://huggingface.co/google/umt5-small either, even though I was training using |
I ran into this with T5-3b (https://huggingface.co/t5-3b/tree/main), using the more recent T5ForSequenceClassification head. I thought it was due to that newer head, but now I see the issue runs deeper. I'll see what my fp32 fine-tuning gives tomorrow, as I believe no comprehensive solution has been put in place just yet. |
🐛 Bug
Hello, thank you for the recent PR with the fp16 fixes. It seems to work well with short inputs, but once the model is fed more complex data it still yields nans.
Information
Model I am using: T5
Language I am using the model on: English
The problem arises when using:
The task I am working on is:
To reproduce
Run the code:
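A sketch of the reproduction, inferred from the comments above (which reuse the same input ids and describe casting all modules to half precision):

```python
import torch
from transformers import T5Model

# cast the whole model to fp16, as described in the comments above
model = T5Model.from_pretrained("t5-base").cuda().half().eval()

inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()

out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:, :2]
```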
output:
Expected behavior
Output with non-nan values.
Environment info
`transformers` version: 2.10.0