[T5] enable T5 fp16 #9487

Merged (2 commits, Jan 12, 2021)
Conversation

patil-suraj (Contributor):

What does this PR do?

This PR enables fp16 for T5 models by clamping hidden states to the max value of the current data type.

As detailed in #9295, T5 produces large (inf) activations in 3 places:

  1. Output of T5LayerFF
  2. Output of T5LayerSelfAttention
  3. Output of T5LayerCrossAttention

To avoid these inf activations, this PR clamps the hidden_states after the above 3 outputs.
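For readability, here is a minimal sketch of the clamping pattern applied at those three points (the clamp_inf_values helper is hypothetical; in the PR the same logic is inlined in the T5 layer forward passes, as in the diff hunk below):

```python
import torch

def clamp_inf_values(hidden_states: torch.Tensor) -> torch.Tensor:
    # In fp16, intermediate activations can overflow to inf; clamp them
    # slightly below the dtype maximum so later layers keep finite values.
    if torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states
```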

@@ -640,6 +640,11 @@ def forward(
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values
        if torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
Contributor:

why the -1000?

Contributor Author:

Just to be on the safer side, setting it to the exact max value might again lead to inf values in subsequent layers
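For illustration (not from the PR thread): fp16 has a narrow dynamic range, so a value left at the exact maximum overflows as soon as a later layer increases it, which is why a safety margin is subtracted:

```python
import torch

print(torch.finfo(torch.float16).max)        # 65504.0, the largest representable fp16 value
x = torch.tensor(65504.0, dtype=torch.float16)
print(x + x)  # tensor(inf, dtype=torch.float16): any further growth overflows
```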

Contributor:

Okay, just noticed that we do the same in Bart as well.

Contributor (commenting on the "# clamp inf values" line):

maybe improve comment slightly:

Suggested change
- # clamp inf values
+ # clamp inf values to enable fp16 training

@patrickvonplaten (Contributor) commented Jan 10, 2021:

This is great!

@exelents commented Jan 10, 2021:

Dear @patil-suraj,
Your PR works well for the t5 model, thank you for your work.
But now I tried the new t5 model version recently released by Google: google/t5-v1_1-xl.
The same code, after loading google/t5-v1_1-xl instead of t5-3b, returns a lot of "overflow" errors.

Can you tell me, should your code fix fp16 on the google/t5-v1_1-xl model?
Here is the training code:
https://github.com/exelents/try_t5_qa
Run ./run-qa-3b.sh

Upd: I ran my code on the Transformers branch from your current PR #9487 merged with PR #9211, which is needed for the deepspeed integration.
Can you confirm the problem, or is it just on my side?

@patrickvonplaten (Contributor), replying to the comment above:
Hey @exelents, can you include a code snippet to reproduce your error as well as the full stack trace of your error?

@patil-suraj (Contributor, Author):

@patrickvonplaten , @exelents

As stated in #9432:

This fix works for the following models and versions, with apex O1 and native amp:

  • T5v1: t5-small, t5-base, t5-large
  • T5v1_1: google/t5-v1_1-small, google/t5-v1_1-base
  • MT5: google/mt5-small, google/mt5-base

Just did a small experiment with t5-v1_1-large and it still gives nan loss after 200 steps, so this might not work for xl.

Also, @exelents, by overflow error do you mean the gradient overflow warning thrown by apex?

@patrickvonplaten (Contributor), replying to the comment above:
Ah ok, we still see nans with t5-v1_1-large then :-/ Do you think this could be fixed by adding one more clamp statement? @patil-suraj

@exelents commented Jan 11, 2021:

> Hey @exelents, can you include a code snippet to reproduce your error as well as the full stack trace of your error?

My code is here:
https://github.com/exelents/try_t5_qa
It requires deepspeed to run, as well as the code from PR #9211 (deepspeed integration) to be merged. Use run-qa-3b.sh to test.

Here is the error stack:
https://gist.github.com/exelents/10f1d03e61059ddf2dfba7068114c93a
Look at the end: we get this message after every step:
[2021-01-11 16:58:18,163] [INFO] [stage2.py:1361:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 256.0, reducing to 128.0
Wait a second, I'll try to check the loss value tensor.

@patil-suraj (Contributor, Author) commented Jan 11, 2021:

> Do you think this could be fixed by adding one more clamp statement?

I'm again trying to locate where exactly in the model this happens. In case it's the same as above (first inf, then nan), we could fix it by adding one more clamp.
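One way to locate the first offending module is to sweep the model with forward hooks; a minimal debugging sketch, not part of this PR (the report_nonfinite_activations helper is hypothetical):

```python
import torch

def report_nonfinite_activations(model):
    # Register forward hooks that print the modules whose output contains
    # inf or NaN, to narrow down where the overflow starts.
    handles = []
    for name, module in model.named_modules():
        def hook(module, inputs, output, name=name):
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out) and not torch.isfinite(out).all():
                print(f"non-finite activation after {name}")
        handles.append(module.register_forward_hook(hook))
    return handles  # call handle.remove() on each when done
```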

@exelents commented:

I have checked the loss value, and it seems it is not NaN. It gets values like "48.7500" or "40.9688", which are valid values. Despite that, I see messages like "OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 1024.0, reducing to 512.0", which seems to mean that something bad happened with the model's loss.

@sgugger (Collaborator) commented Jan 11, 2021:

> "OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 1024.0, reducing to 512.0", which seems to mean that something bad happened with the model's loss.

Those warnings don't mean anything went wrong; with dynamic loss scaling, it is expected that some loss scale values are too big at the beginning of training.
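For context, this is roughly how dynamic loss scaling behaves in PyTorch's native amp; a minimal sketch of a single training step (model, optimizer, and batch are assumed placeholders for a Hugging Face model that returns a loss):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # dynamic loss scaling: starts high and backs off on overflow

def training_step(model, batch, optimizer):
    optimizer.zero_grad()
    with autocast():                  # forward pass in mixed precision
        loss = model(**batch).loss
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)            # skips the optimizer step if gradients overflowed
    scaler.update()                   # halves the scale after an overflow, grows it periodically otherwise
    return loss.detach()
```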

@sgugger (Collaborator) left a review comment:

LGTM, thanks for fixing this!

@LysandreJik (Member) left a review comment:

Very cool! Thanks for working on this @patil-suraj!
