
[T5/MT5] resolve inf/nan under amp (mixed precision) #10956

Closed
wants to merge 14 commits into from

Conversation

stas00
Contributor

@stas00 stas00 commented Mar 29, 2021

As reported in multiple issues, t5/mt5 models produce a NaN loss under mixed-precision training, starting with t5-large and mt5-small and up. This PR is an attempt to fix that. This is crucial for DeepSpeed, where training is always mixed precision.

I spent some time with the debugger and the new detect_overflow helper util (added in this PR) and discovered that the best place to fix the whole problem is to not run T5LayerFF in mixed precision. This slightly slows things down / consumes more gpu memory, but no longer requires clamping and chasing after ever-overflowing hidden_states.

This PR:

  • turns autocast off during T5LayerFF if run under amp
  • removes the previous attempt to clamp the values as it now works without it
  • introduces debug_utils.py with a helper function detect_overflow, which is super-handy for tracking overflows automatically (it's silent if all goes well). It also has some extra features, such as reporting the number of large elements - disabled by default. A rough sketch of the helper is shown below.
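A minimal sketch of such a helper, assuming it takes the tensor and a context string (the real debug_utils.py version has extra reporting features beyond this):

import torch

def detect_overflow(var: torch.Tensor, ctx: str) -> bool:
    # report inf/nan elements in `var`, tagging the message with `ctx`; silent otherwise
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")
    return detected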

Important:

  • The fix is only for pytorch's built-in amp. apex still has this problem since I haven't researched whether the same could be done there, but it's probably a waste of time since apex is being phased out. And DeepSpeed doesn't use amp, so it's still affected.

Variations

Other possible variations to this solution:

  1. do the autocast disabling dynamically. That is, try with autocast, check whether any elements of the output are inf (not sure of the overhead), and if so re-run this layer in full fp32 and set a flag to continue in fp32 from then on (see the sketch below). Here the main price will be paid by models that don't need this workaround, but they will gain by not having autocast turned off - so it might still be a beneficial solution for all.
  2. give users a switch to turn this feature on if they discover they need it - or have it on by default and allow users to turn it off if they "know what they are doing".

I am suggesting this since I don't know if all t5/mt5 models are impacted. Definitely t5-small doesn't need this.
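For illustration, here is a minimal self-contained sketch of variation 1; the class and its layers are hypothetical stand-ins, not the real T5LayerFF:

import torch
import torch.nn as nn

class FFWithDynamicAutocast(nn.Module):
    # hypothetical FF block that falls back to fp32 only after it first overflows
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.force_fp32 = False  # sticky flag, set on the first detected overflow

    def _forward(self, hidden_states):
        return self.wo(torch.relu(self.wi(hidden_states)))

    def forward(self, hidden_states):
        if torch.is_autocast_enabled() and not self.force_fp32:
            out = self._forward(hidden_states)
            if torch.isfinite(out).all():
                return out
            # overflow under fp16 autocast - redo this layer in fp32 and stay there
            self.force_fp32 = True
        with torch.cuda.amp.autocast(enabled=False):
            return self._forward(hidden_states.float())

Models that never overflow would only pay the cost of the extra isfinite check; models that do overflow pay the fp32 cost from that point on.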

Penalizing large activations

See the detailed comment: #10956 (comment)

@@ -1578,6 +1618,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

+            # z_loss
+            log_z = lm_logits.view(-1).logsumexp(-1)
+            z_loss = 7e-5
+            loss_extra = z_loss*log_z.square()
+            #z_loss = 1e-5
+            #loss_extra = z_loss*log_z.pow(3)
+            #print(f"loss={loss}, loss_extra={loss_extra}")
+            loss += loss_extra

Questions:

  • If this solution solves the problem at large and is accepted, then we should probably document somewhere in the t5/mt5 docs that it won't run AMP 100%.

  • A test is needed: any suggestions for how we could write a test that is not too big and still gets NaNs prior to this PR? t5-small and t5-base don't have this problem (at least with a small sample); in my experiments the first model that gets inf/nan on the first batch is mt5-small (1.2GB), so my minimal test is:

rm -rf output_dir; CUDA_VISIBLE_DEVICES=0 USE_TF=0 PYTHONPATH=src python examples/seq2seq/run_translation.py \
--model_name_or_path google/mt5-small --do_train  --source_lang en --target_lang ro --dataset_name wmt16 \
--dataset_config_name ro-en --output_dir output_dir --per_device_train_batch_size=4  --logging_step 2 --save_steps 0 \
--fp16 --max_train_samples 10 --save_total_limit 0 --save_strategy no

We can then run this as a test and check for nan in loss reports.

But the 1.2GB download is somewhat big even for @slow tests.
edit: @LysandreJik says it's not a problem since we are now caching the models on the test machine.

If that's OK, I will just put this with all the extended tests under examples/tests/trainer/test_trainer_ext.py, where we have a setup for this type of full application-based test; a rough sketch of such a test is shown after this list.

  • I also know some users mentioned that inf may happen much later in the game. I haven't run very long tests.
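A rough sketch of what such a test could look like - the run_trainer helper here is only a placeholder for whatever utilities test_trainer_ext.py actually provides:

import math

def test_mt5_fp16_no_nan(self):
    # run the mt5-small fp16 translation command above via a shared helper
    # and make sure the reported training loss stays finite
    metrics = self.run_trainer(model_name="google/mt5-small", fp16=True, max_train_samples=10)
    assert not math.isnan(metrics["train_loss"]), "loss went to nan under fp16"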

FYI, there is another solution posted by @ibeltagy here: #14189 (comment). It is based on custom scaling.


TODO:

  • I left all the debug prints in place so that you can experiment with it easily - I will remove them once this is confirmed to be a good change.

Related discussions:

Fixes: #10830
Fixes: #10819

@patrickvonplaten, @patil-suraj, @LysandreJik

Comment on lines 308 to 327
def _forward(self, hidden_states):
    detect_overflow(hidden_states, "T5LayerFF: 1")
    forwarded_states = self.layer_norm(hidden_states)
    detect_overflow(forwarded_states, "T5LayerFF: 2")
    forwarded_states = self.DenseReluDense(forwarded_states)
    detect_overflow(forwarded_states, "T5LayerFF: 3")
    hidden_states = hidden_states + self.dropout(forwarded_states)
    detect_overflow(hidden_states, "T5LayerFF: 5")
    return hidden_states

def forward(self, hidden_states):
    # many t5/mt5 models are trained in bfloat16 and don't do well under mixed precision (fp16).
    # It appears that it's enough to disable autocast for this FF layer to avoid inf/nan
    # problems for the whole model
    if torch.is_autocast_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            return self._forward(hidden_states)
    else:
        return self._forward(hidden_states)

Contributor Author

This is the core of the change.

Contributor

I'm fine with this!

@LysandreJik
Member

But the 1.2GB download is somewhat big even for @slow tests.

The downloads are cached on a shared disk across slow self-hosted runners, so that's not an issue!

@stas00
Contributor Author

stas00 commented Mar 30, 2021

Before I approached this problem, I did a bit of a study on the bfloat16 vs float16 properties. This is not fully complete, but you can see most of the useful data here: https://github.com/stas00/ml-ways/blob/master/numbers/bfloat16-vs-float16-study.ipynb

Comments/requests/suggestions are welcome though. It's a bit on the terse side.

@stas00
Contributor Author

stas00 commented Apr 5, 2021

I spent some more time staring at the numbers. As I think @patrickvonplaten mentioned in one of the related threads, something trained in bfloat16 isn't going to work with float16. You can see why by looking at this debug output:

min=-2.77e+05 max= 2.75e+05 var= 5.45e+07 mean= 5.16e+01 (T5Stack loop start)
min=-2.77e+05 max= 2.75e+05 var= 5.45e+07 mean= 5.16e+01 (T5Block)
min=-2.77e+05 max= 2.75e+05 var= 5.45e+07 mean= 5.16e+01 (T5LayerNorm)
min= 1.31e+06 max= 6.90e+08 var= 9.52e+15 mean= 5.45e+07 (T5LayerNorm variance)
min=-1.46e+01 max= 1.46e+01 var= 1.00e+00 mean=-2.69e-03 (T5LayerNorm hidden_states)
min=-1.46e+01 max= 1.46e+01 var= 1.00e+00 mean=-2.69e-03 (T5LayerNorm hidden_states before return)
min=-2.76e+05 max= 2.74e+05 var= 5.41e+07 mean= 4.83e+01 (T5Block after T5LayerSelfAttention)
min=-2.76e+05 max= 2.74e+05 var= 5.41e+07 mean= 4.83e+01 (T5LayerNorm)
min= 1.38e+06 max= 6.86e+08 var= 9.37e+15 mean= 5.41e+07 (T5LayerNorm variance)
min=-1.45e+01 max= 1.46e+01 var= 1.00e+00 mean=-2.98e-03 (T5LayerNorm hidden_states)
min=-1.45e+01 max= 1.46e+01 var= 1.00e+00 mean=-2.98e-03 (T5LayerNorm hidden_states before return)
min=-2.76e+05 max= 2.73e+05 var= 5.40e+07 mean= 3.93e+01 (T5Block before T5LayerFF)
min=-2.76e+05 max= 2.73e+05 var= 5.40e+07 mean= 3.93e+01 (T5LayerFF: 1)
min=-2.76e+05 max= 2.73e+05 var= 5.40e+07 mean= 3.93e+01 (T5LayerNorm)
min= 1.61e+06 max= 6.84e+08 var= 9.28e+15 mean= 5.40e+07 (T5LayerNorm variance)
min=-1.44e+01 max= 1.46e+01 var= 1.00e+00 mean=-5.14e-03 (T5LayerNorm hidden_states)
min=-1.44e+01 max= 1.46e+01 var= 1.00e+00 mean=-5.14e-03 (T5LayerNorm hidden_states before return)
min=-2.47e+00 max= 3.03e+00 var= 4.43e-02 mean=-8.23e-05 (T5LayerFF: 2)
min=-1.70e-01 max= 4.95e+01 var= 6.34e-01 mean= 3.00e-01 (gelu 1)
min=-3.70e+02 max= 3.93e+02 var= 3.79e+02 mean= 2.79e-01 (gelu 2)
min=-4.71e+03 max= 3.67e+03 var= 1.89e+03 mean=-3.80e-01 (gelu 3)
min=-5.23e+03 max= 4.08e+03 var= 2.21e+03 mean=-4.75e-01 (gelu 4)
min=-7.11e+04 max= 5.32e+04 var= 8.27e+06 mean=-1.36e+02 (gelu 5)
min=-7.11e+04 max= 5.32e+04 var= 8.27e+06 mean=-1.36e+02 (T5LayerFF: 3)
min=-2.61e+05 max= 2.68e+05 var= 4.41e+07 mean=-1.04e+02 (T5LayerFF: 5)
min=-2.61e+05 max= 2.68e+05 var= 4.41e+07 mean=-1.04e+02 (T5Block after T5LayerFF)
min=-2.61e+05 max= 2.68e+05 var= 4.41e+07 mean=-1.04e+02 (T5Stack loop end)
min=-2.61e+05 max= 2.68e+05 var= 4.41e+07 mean=-1.04e+02 (T5LayerNorm)
min= 2.99e+06 max= 6.12e+08 var= 5.65e+15 mean= 4.41e+07 (T5LayerNorm variance)
min=-1.45e+01 max= 1.62e+01 var= 1.00e+00 mean=-2.27e-02 (T5LayerNorm hidden_states)
min=-1.45e+01 max= 1.62e+01 var= 1.00e+00 mean=-2.27e-02 (T5LayerNorm hidden_states before return)

Because bfloat16 lacks precision, the model trained itself to compensate by moving into the range of large numbers. If you look at the numbers above you can see that many of them are way beyond the fp16 range, which can only represent roughly ±64K.
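A quick way to see those ranges in pytorch:

import torch

print(torch.finfo(torch.float16))   # max = 65504, 10 mantissa bits
print(torch.finfo(torch.bfloat16))  # max ≈ 3.39e+38, but only 7 mantissa bits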

So if I understand the nature of this problem correctly, expecting this to work is a bit of a fantasy. But of course, let's try to do our best to come as close to a solution as possible.

I found that, for the simple case, it's enough to cancel autocast just for self.DenseReluDense to not produce NaN.
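A rough sketch of that narrower variant, written as a replacement forward for the T5LayerFF shown earlier (illustration only - layer_norm, DenseReluDense and dropout are the module's existing sub-layers):

def forward(self, hidden_states):
    forwarded_states = self.layer_norm(hidden_states)
    if torch.is_autocast_enabled():
        # only the FF projections run outside of autocast, in fp32;
        # attention and the rest of the block stay under amp
        with torch.cuda.amp.autocast(enabled=False):
            forwarded_states = self.DenseReluDense(forwarded_states.float())
    else:
        forwarded_states = self.DenseReluDense(forwarded_states)
    return hidden_states + self.dropout(forwarded_states)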

@stas00
Contributor Author

stas00 commented Apr 15, 2021

@yuvalkirstain, let's switch the discussion to the actual PR

Regarding your newly discovered overflow:

Please try adding this penalty for large logits:

@@ -1578,6 +1618,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

+            # z_loss
+            log_z = lm_logits.view(-1).logsumexp(-1)
+            z_loss = 7e-5
+            loss_extra = z_loss*log_z.square()
+            #z_loss = 1e-5
+            #loss_extra = z_loss*log_z.pow(3)
+            #print(f"loss={loss}, loss_extra={loss_extra}")
+            loss += loss_extra

The z_loss factor may need some tuning for best convergence. The recommended value is 1e-4, so I've experimented with a few. I also tried pow(3) instead of pow(2).

It seems that the network gets the hint within just ~100 steps - loss_extra drops very quickly.

Perhaps this was the missing piece?
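For reference, the z_loss idea from the mesh_tensorflow link above penalizes the squared log-partition function of the logits. A minimal per-token sketch of that variant (my own illustration, not the exact diff above; the 1e-4 default is the recommended value mentioned here):

import torch

def z_loss_term(lm_logits: torch.Tensor, z_loss: float = 1e-4) -> torch.Tensor:
    # log Z per token: logsumexp over the vocabulary dimension
    log_z = lm_logits.logsumexp(dim=-1)
    # penalize large log-partition values, which keeps the logits from drifting
    # into the fp16 overflow range
    return z_loss * log_z.square().mean()

This term would then be added to the cross-entropy loss, mirroring the diff above.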

@stas00
Contributor Author

stas00 commented Apr 27, 2021

Here is the output of the proposed overflow/underflow detector tool (still in progress) for mt5. This is prior to any of the modifications proposed in this PR, so one can see the progression as the weights and activations change from forward to forward.

rm -rf output_dir; CUDA_VISIBLE_DEVICES=0 USE_TF=0 PYTHONPATH=src \
python examples/pytorch/translation/run_translation.py --model_name_or_path google/mt5-small --do_train \
--source_lang en --target_lang ro --dataset_name \
wmt16 --dataset_config_name ro-en --output_dir output_dir --per_device_train_batch_size=4  --logging_step 2 --save_steps 0 \
--fp16 --max_train_samples 10 --save_total_limit 0 --save_strategy no --debug underflow_overflow
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min  abs max  metadata
                  encoder.block.1.layer.1.DenseReluDense.dropout Dropout
0.00e+00 2.57e+02 input[0]
0.00e+00 2.85e+02 output
                  encoder.block.1.layer.1.DenseReluDense.wo Linear
4.80e-06 8.62e+00 weight
0.00e+00 2.85e+02 input[0]
8.50e-05 1.53e+03 output
                  encoder.block.1.layer.1.DenseReluDense T5DenseGatedGeluDense
0.00e+00 2.04e+00 input[0]
8.50e-05 1.53e+03 output
                  encoder.block.1.layer.1.dropout Dropout
8.50e-05 1.53e+03 input[0]
0.00e+00 1.70e+03 output
                  encoder.block.1.layer.1 T5LayerFF
0.00e+00 1.50e+03 input[0]
6.78e-04 3.15e+03 output
                  encoder.block.1 T5Block
0.00e+00 1.40e+03 input[0]
6.78e-04 3.15e+03 output[0]
             None output[1]
2.25e-01 1.00e+04 output[2]
                  encoder.block.2.layer.0.layer_norm T5LayerNorm
6.54e-02 2.75e-01 weight
6.78e-04 3.15e+03 input[0]
5.75e-06 2.12e+00 output
                  encoder.block.2.layer.0.SelfAttention.q Linear
3.75e-08 3.40e-01 weight
5.75e-06 2.12e+00 input[0]
2.21e-06 1.20e+00 output
                  encoder.block.2.layer.0.SelfAttention.k Linear
4.84e-08 2.62e+00 weight
5.75e-06 2.12e+00 input[0]
5.47e-05 1.40e+01 output
                  encoder.block.2.layer.0.SelfAttention.v Linear
7.21e-06 2.59e+00 weight
5.75e-06 2.12e+00 input[0]
1.20e-04 7.56e+00 output
                  encoder.block.2.layer.0.SelfAttention.o Linear
6.65e-06 1.44e+01 weight
0.00e+00 5.30e+00 input[0]
5.20e-04 2.66e+02 output
                  encoder.block.2.layer.0.SelfAttention T5Attention
5.75e-06 2.12e+00 input[0]
5.20e-04 2.66e+02 output[0]
             None output[1]
2.25e-01 1.00e+04 output[2]
                  encoder.block.2.layer.0.dropout Dropout
5.20e-04 2.66e+02 input[0]
0.00e+00 2.96e+02 output
                  encoder.block.2.layer.0 T5LayerSelfAttention
6.78e-04 3.15e+03 input[0]
2.65e-04 3.42e+03 output[0]
             None output[1]
2.25e-01 1.00e+04 output[2]
                  encoder.block.2.layer.1.layer_norm T5LayerNorm
8.69e-02 4.18e-01 weight
2.65e-04 3.42e+03 input[0]
1.79e-06 4.65e+00 output
                  encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
                  encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
                  encoder.block.2.layer.1.DenseReluDense.dropout Dropout
0.00e+00 8.76e+03 input[0]
0.00e+00 9.74e+03 output
                  encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00      inf output

@dblakely
Contributor

Hi there, I'm wondering what the current status of this is, as my team would benefit from a fix to the fp16 issue with large T5 models. Is there anything we could do to help move the PR along?

In the meantime, it should be sufficient to simply disable autocast for the DenseReluDense, correct?

@stas00
Contributor Author

stas00 commented May 11, 2021

Hi there, I'm wondering what the current status of this is, as my team would benefit from a fix to the fp16 issue with large T5 models. Is there anything we could do to help move the PR along?

@yuvalkirstain, who is one of the original reporters, mentioned elsewhere that he still had an issue during a long training run, so I was waiting for him to provide more details.

In the meantime, it should be sufficient to simply disable autocast for the DenseReluDense, correct?

If you're not using deepspeed, then yes, that is all that is needed - at least for the tests I have done, but they weren't long.

Perhaps you could test this PR and report back if it solves your problem?

I'm not sure if I should remove the clamping or not.

I cleaned up the PR to remove all the debug noise, so it's very simple now.

@stas00
Contributor Author

stas00 commented Dec 3, 2021

So, heads up to all those watching this PR - if you have Ampere GPUs, start using --bf16, which was just added (i.e. use master), and the overflow problem will be no more: https://huggingface.co/docs/transformers/master/performance#bf16

p.s. I haven't actually tested that this is the case for this particular issue, so if you do and find that it is not, please kindly flag it to me.

@stas00
Contributor Author

stas00 commented Dec 18, 2021

So I got a chance to test the new --bf16 flag on this issue with an RTX-3090 (Ampere), and now mt5 doesn't overflow:

rm -rf output_dir; CUDA_VISIBLE_DEVICES=0 USE_TF=0 PYTHONPATH=src python \
examples/pytorch/translation/run_translation.py --model_name_or_path google/mt5-small --do_train \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en --output_dir \
output_dir --per_device_train_batch_size=4 --logging_step 2 --save_steps 0 --max_train_samples 10 \
--save_total_limit 0 --save_strategy no --bf16

***** train metrics *****
  epoch                    =        3.0
  train_loss               =    28.7758
  train_runtime            = 0:00:01.94
  train_samples            =         10
  train_samples_per_second =     15.458
  train_steps_per_second   =      4.637

We get the same outcome with fp32 (i.e. without --bf16).

With --fp16 we still overflow (no surprise here, I have just re-checked that):

rm -rf output_dir; CUDA_VISIBLE_DEVICES=0 USE_TF=0 PYTHONPATH=src python \
examples/pytorch/translation/run_translation.py --model_name_or_path google/mt5-small --do_train \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en --output_dir \
output_dir --per_device_train_batch_size=4 --logging_step 2 --save_steps 0 --max_train_samples 10 \
--save_total_limit 0 --save_strategy no --fp16

***** train metrics *****
  epoch                    =        3.0
  train_loss               =        0.0
  train_runtime            = 0:00:01.74
  train_samples            =         10
  train_samples_per_second =      17.24
  train_steps_per_second   =      5.172

@GaryYufei

This PR forces the T5 FF layer into fp32. With this change, there is almost no benefit to training in fp16: the memory usage and training-speed improvements are very limited.

Hi, same issue here. Did you find a way to fix it?

@Liangtaiwan
Contributor

@GaryYufei
I don't think this issue can really be fixed.
The best way is to use GPUs that support bf16 training.
You may also try the method proposed by @tlkh, but I'm not sure whether there would be any side-effects.

In my case, I chose to use fp32 to finetune the model.


@TianHongZXY

Is this problem solved in the latest release of transformers?

@stas00
Contributor Author

stas00 commented Apr 6, 2022

It's not really a problem in transformers per se, but a limitation of the model.

This PR didn't get merged as it helped only in some cases and it would introduce a slowdown.

This thread contains various possible workarounds, but the best solution at the moment is to use bf16-capable hardware (Ampere GPUs or TPUs) to finetune t5 and any other bf16-pretrained models.

@TianHongZXY

I see, thank you for your answer

killight98 added a commit to killight98/transformers that referenced this pull request May 29, 2022
@djaym7

djaym7 commented Sep 6, 2022

Can someone approve this? I'm getting nan values on the main branch.

@stas00
Contributor Author

stas00 commented Sep 6, 2022

It hasn't been merged because it's not an ideal solution: it introduces a performance degradation (upcasting to fp32 = more memory used) and it doesn't always resolve the problem.

This is a curse of many bf16-pretrained models used in fp16 mode, not just T5 and its derivatives.

Do you by chance have access to Ampere GPUs and the ability to use bf16 instead of fp16? That would solve the problem without changing the code. #10956 (comment)

@silencio94

In my case, the issue with t5 training producing nan values in mixed precision (using torch native amp) has been resolved by this PR. Thank you very much. Have a great day!

@ArthurZucker
Collaborator

Closing as this PR is super old and not planned

Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress
Development

Successfully merging this pull request may close these issues.

getting nans with t5-large
fix mt5 getting nans with fp16