Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[T5] enable T5 fp16 #9487

Merged
merged 2 commits into from
Jan 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,11 @@ def forward(
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights

# clamp inf values to enable fp16 training
if torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the -1000?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be on the safer side, setting it to the exact max value might again lead to inf values in subsequent layers

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okey just noticed that we do the same in Bart as well

hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
# the actual query length is unknown for cross attention
Expand All @@ -661,6 +666,10 @@ def forward(
output_attentions=output_attentions,
)
hidden_states = cross_attention_outputs[0]
if torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
Expand All @@ -670,6 +679,9 @@ def forward(

# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
if torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)

outputs = outputs + (present_key_value_state,) + attention_outputs
Expand Down