[T5/MT5] resolve inf/nan under amp (mixed precision) #10956
Conversation
def _forward(self, hidden_states):
    detect_overflow(hidden_states, "T5LayerFF: 1")
    forwarded_states = self.layer_norm(hidden_states)
    detect_overflow(forwarded_states, "T5LayerFF: 2")
    forwarded_states = self.DenseReluDense(forwarded_states)
    detect_overflow(forwarded_states, "T5LayerFF: 3")
    hidden_states = hidden_states + self.dropout(forwarded_states)
    detect_overflow(hidden_states, "T5LayerFF: 5")
    return hidden_states

def forward(self, hidden_states):
    # many t5/mt5 models are trained in bfloat16 and don't do well under mixed precision (fp16).
    # It appears that it's enough to disable autocast for this FF layer to avoid inf/nan
    # problems for the whole model
    if torch.is_autocast_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            return self._forward(hidden_states)
    else:
        return self._forward(hidden_states)
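The `detect_overflow` calls above refer to the helper this PR adds in `debug_utils.py`; it stays silent unless something overflows. As a rough illustration of the idea only (not necessarily the PR's exact implementation), such a helper could look like:

```python
import torch

def detect_overflow(var: torch.Tensor, ctx: str) -> bool:
    """Print a message if `var` contains any nan/inf elements; stay silent otherwise."""
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")
    return detected
```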
This is the core of the change.
I'm fine with this!
The downloads are cached on a shared disk across the slow self-hosted runners, so that's not an issue!
Before I approached this problem, I did a bit of a study on the bfloat16 vs float16 properties. This is not fully complete, but you can see most of the useful data here: https://github.com/stas00/ml-ways/blob/master/numbers/bfloat16-vs-float16-study.ipynb Comments/requests/suggestions are welcome, though it's a bit on the terse side.
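As a quick way to see the core difference the study is about, `torch.finfo` already summarizes it; this snippet is just an illustration and is not taken from the notebook:

```python
import torch

# float16 tops out around 6.5e4, while bfloat16 shares float32's exponent range,
# which is why bf16-pretrained activations can overflow once run under fp16 autocast.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>15}: max={info.max:.3e}  smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")
```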
I spent some more time staring at the numbers. As I think @patrickvonplaten mentioned in one of the related threads, something trained in bfloat16 isn't guaranteed to behave under fp16 mixed precision, because the numerical ranges of the two formats are so different. So if I understand the nature of this problem correctly, expecting this to work is a bit of a fantasy. But of course, let's try to do our best to come as close to a solution as possible. I found that it's enough to cancel autocast just for T5LayerFF.
@yuvalkirstain, let's switch the discussion to the actual PR wrt your newly discovered overflow. Please try to add this penalty for large logits:
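As a hedged sketch of the idea only: a small z-loss-style penalty on the log-partition function that discourages the network from producing very large logits. The helper name and the `z_loss_scale` value below are illustrative, not from the PR:

```python
import torch

def add_logit_penalty(loss: torch.Tensor, lm_logits: torch.Tensor,
                      z_loss_scale: float = 1e-4) -> torch.Tensor:
    # Penalize a large log-partition function (logsumexp over the vocab),
    # similar in spirit to the z_loss used in the mesh-tensorflow T5 codebase.
    log_z = torch.logsumexp(lm_logits.float(), dim=-1)
    return loss + z_loss_scale * log_z.pow(2).mean()
```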
May need some tuning of the penalty scale. It seems that the network gets the hint within just ~100 steps - perhaps this was the missing piece?
Here is the output of the proposed work-in-progress overflow/underflow detector tool for mt5. This is prior to any modifications proposed in this PR, so one can see the progression as the weights and activations change from forward to forward.
Hi there, I'm wondering what the current status of this is, as my team would benefit from a fix to the fp16 issue with large T5 models. Is there anything we could do to help move the PR along? In the meantime, it should be sufficient to simply disable autocast for the DenseReluDense, correct?
@yuvalkirstain, who is one of the original reporters, mentioned elsewhere that he still had an issue during long training, so I was waiting for him to provide more details.
If you're not using deepspeed, then yes, that is all that is needed - at least for the tests I have done, but they weren't long. Perhaps you could test this PR and report back whether it solves your problem? I'm not sure if I should remove the clamping or not. I cleaned up the PR to remove all the debug noise, so it's very simple now.
So heads up to all those watching this PR - if you have Ampere GPUs, start using bf16 mixed precision instead of fp16, which should avoid this overflow issue altogether.

p.s. I haven't actually tested that it's so with this particular issue, so if you do and find that it is not so, please kindly flag this to me.
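For reference, a minimal sketch of what that looks like in plain PyTorch, assuming PyTorch 1.10+ with bf16 autocast support (`model` and `batch` stand in for your own objects):

```python
import torch

if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    # On Ampere and newer GPUs autocast can target bfloat16 instead of float16,
    # keeping the fp32-like dynamic range that t5/mt5 were pretrained with.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
```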
So I got a chance to test the new bf16 support: we get the same outcome as with fp32 (i.e. w/o fp16 mixed precision). With bf16 the overflow problem doesn't appear.
Hi, same issue here. Did you find a way to fix it?
@GaryYufei In my case, I chose to use fp32 to finetune the model.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Is this problem solved in the latest release of transformers?
It's not really a problem in transformers itself. This PR didn't get merged as it helped only in some cases and it would introduce a slowdown. This thread contains various possible workarounds, but the best solution at the moment is to use bf16-capable hardware (Ampere GPUs or TPUs) to finetune t5 and any other bf16-pretrained models.
I see, thank you for your answer |
Refer to [PR huggingface#10956](huggingface#10956).
Can someone approve this? I'm getting nan values on the main branch.
It hasn't been merged because it's not the ideal solution: it introduces a degradation in performance (scales to fp32 = more memory used) and it doesn't always resolve the problem. This is a curse of many bf16-pretrained models used in fp16 mode, not just of T5 and its derivatives. Do you by chance have access to Ampere GPUs and the ability to use bf16 instead of fp16? That would solve the problem w/o changing the code. #10956 (comment)
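With the HF Trainer that switch is a config change rather than a code change; a sketch assuming a transformers version recent enough to expose the `bf16` training argument:

```python
from transformers import TrainingArguments

# Train in bfloat16 mixed precision instead of float16 (requires Ampere+ GPUs).
args = TrainingArguments(
    output_dir="output",
    bf16=True,   # instead of fp16=True
)
```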
In my case, the issue with t5 training causing nan values in mixed precision (using torch native amp) has been resolved following this PR. Thank you very much. Have a great day!
Closing as this PR is super old and not planned
As reported in multiple issues, t5/mt5 models produce a loss of `nan` under mixed precision training, starting with t5-large and mt5-small and up. This PR is an attempt to fix this issue. This is crucial for DeepSpeed, where it's always mixed precision training.

I spent some time with the debugger and the new `detect_overflow` helper util (added in this PR) and discovered that the best place to fix the whole problem is to not run `T5LayerFF` in mixed precision. This slightly slows things down / consumes more gpu memory, but no longer requires clamping and chasing after the ever-overflowing `hidden_states`.

This PR:
- turns `autocast` off during `T5LayerFF` if run under amp
- adds `debug_utils.py` with a helper function `detect_overflow`, which is super-handy for tracking overflows automatically (as it's silent if all goes well). It also has some extra features, such as reporting the number of large elements - disabled by default.

Important:
Variations

Other possible variations to this solution:
- do the `autocast` disabling dynamically: that is, trying with `autocast` and checking if any elements of the output are `inf` (not sure of the overhead), then re-running this layer in full fp32 and setting a flag to continue in fp32 from then on. Here the main price will be paid by models that don't need this workaround, but they will gain by not having `autocast` turned off - so it might still be a beneficial solution for all. I am suggesting this since I don't know if all t5/mt5 models are impacted; definitely t5-small doesn't need this. (A rough sketch of this dynamic variation follows below.)
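A hypothetical sketch of that dynamic variation, using a made-up `_force_fp32` flag and the same `_forward` split as the diff above:

```python
import torch
from torch.cuda.amp import autocast

def forward(self, hidden_states):
    # Try the cheap path first (under autocast, if it is enabled).
    if not getattr(self, "_force_fp32", False):
        out = self._forward(hidden_states)
        if not torch.isinf(out).any():
            return out
        # First overflow detected: from now on run this layer in full fp32.
        self._force_fp32 = True
    with autocast(enabled=False):
        return self._forward(hidden_states.float())
```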
Penalizing large activation
See the details comment: #10956 (comment)
Questions:
- If this solution solves the problem at large and is accepted, then we probably should document somewhere in the t5/mt5 docs that it won't run AMP 100%?
- A test is needed: any suggestions for how we could write a test that is not too big and still gets nans prior to this PR? `t5-small` and `t5-base` don't have this problem (at least with a small sample); in my experiments the first model that gets `inf/nan` on the first batch is `mt5-small` (1.2GB), so that's what my minimal test uses. We can then run this as a test and check for `nan` in the loss reports (a rough sketch of such a check is shown after this list). But the 1.2GB download is somewhat big even for `@slow` tests. edit: @LysandreJik says it's not a problem since we are now caching the models on the test machine. If it is ok, I will just stick this with all the extended tests under `examples/tests/trainer/test_trainer_ext.py`, where we have a setup for this type of full application-based tests.
- `inf` may happen much later in the game; I haven't run very long tests.

FYI, there is another solution posted by @ibeltagy here: #14189 (comment) - it is based on custom scaling.
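As referenced in the list above, a rough sketch of what such a minimal nan check could look like. This is illustrative only, not the actual test from the PR; it assumes a CUDA GPU and PyTorch with fp16 autocast support:

```python
import torch
from transformers import AutoTokenizer, MT5ForConditionalGeneration

def check_mt5_small_fp16_loss():
    # One fp16-autocast forward pass of mt5-small; per the report above, the loss
    # can already overflow to nan on the first batch without a fix.
    model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small").cuda()
    tok = AutoTokenizer.from_pretrained("google/mt5-small")
    inputs = tok(["The house is wonderful and I like it a lot."], return_tensors="pt").to("cuda")
    labels = tok(["Das Haus ist wunderbar."], return_tensors="pt").input_ids.to("cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(**inputs, labels=labels).loss
    assert not torch.isnan(loss).any(), f"got nan loss: {loss}"
```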
TODO:
Related discussions:
Fixes: #10830
Fixes: #10819
@patrickvonplaten, @patil-suraj, @LysandreJik