T5 Gradient Checkpointing #11353

Merged
10 commits merged on Apr 30, 2021
4 changes: 4 additions & 0 deletions src/transformers/models/t5/configuration_t5.py
@@ -71,6 +71,8 @@ class T5Config(PretrainedConfig):
             the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
     """
     model_type = "t5"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -93,6 +95,7 @@ def __init__(
         use_cache=True,
         pad_token_id=0,
         eos_token_id=1,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(
@@ -116,6 +119,7 @@ def __init__(
         self.initializer_factor = initializer_factor
         self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
+        self.gradient_checkpointing = gradient_checkpointing

     @property
     def hidden_size(self):
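For illustration (not part of the diff): the new flag travels through T5Config like any other attribute, so it can be passed at construction time or toggled on an existing config; the modeling code below reads it defensively via getattr.

from transformers import T5Config

config = T5Config(gradient_checkpointing=True)  # new keyword argument introduced above
assert config.gradient_checkpointing is True
config.gradient_checkpointing = False  # can also be flipped later on a loaded config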
66 changes: 52 additions & 14 deletions src/transformers/models/t5/modeling_t5.py
@@ -24,6 +24,7 @@
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint

 from ...activations import ACT2FN
 from ...file_utils import (
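For context, a standalone sketch (not T5-specific) of what the newly imported torch.utils.checkpoint.checkpoint does: the wrapped module runs without storing its intermediate activations, which are recomputed during the backward pass, trading extra compute for lower memory.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

out = checkpoint(block, x)  # activations inside `block` are not kept; they are recomputed on backward
out.sum().backward()        # gradients match a plain block(x) call, at the cost of a second forward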
@@ -323,6 +324,7 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False):
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()
+        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)

     def prune_heads(self, heads):
         if len(heads) == 0:
@@ -485,6 +487,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                 position_bias = torch.zeros(
                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                 )
+                if self.training and self.gradient_checkpointing:
+                    position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
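A hedged aside on the two lines added above: torch.utils.checkpoint only propagates gradients through tensors flagged as requiring them, so the freshly created zero bias is marked with requires_grad=True, presumably to keep the checkpointed segments that consume it connected to autograd. A standalone sketch of that generic behaviour (not T5 code):

import torch
from torch.utils.checkpoint import checkpoint

def layer(hidden, bias):
    return hidden * 2 + bias

hidden = torch.randn(2, 4, requires_grad=True)
bias = torch.zeros(1, 4)
bias.requires_grad = True  # analogous to `position_bias.requires_grad = True` above

out = checkpoint(layer, hidden, bias)
out.sum().backward()
print(bias.grad.shape)     # torch.Size([1, 4]); without the flag, bias.grad would stay None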

@@ -691,7 +695,11 @@ def forward(

         outputs = (hidden_states,)

-        outputs = outputs + (present_key_value_state,) + attention_outputs
+        if use_cache:
+            outputs = outputs + (present_key_value_state,) + attention_outputs
+        else:
+            outputs = outputs + attention_outputs
+
         return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
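The shorter tuple returned here when use_cache is False is compensated for further down in this diff, where T5Stack.forward re-inserts a None placeholder before unpacking. A tiny standalone sketch of that re-alignment (stand-in values, not T5 code):

hidden_states = "hidden"                   # stand-ins for the real tensors
attention_outputs = ("attention_weights",)

layer_outputs = (hidden_states,) + attention_outputs             # block output with use_cache=False
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]  # fix-up done in T5Stack.forward
hidden_states, present_key_value_state = layer_outputs[:2]
assert present_key_value_state is None                           # downstream unpacking still works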


@@ -947,21 +955,51 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask=extended_attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_extended_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=layer_head_mask,
-                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
-                past_key_value=past_key_value,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return tuple(module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                layer_outputs = checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    extended_attention_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    cross_attn_layer_head_mask,
+                    None,  # past_key_value is always None with gradient checkpointing
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    position_bias=position_bias,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+            if use_cache is False:
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
             hidden_states, present_key_value_state = layer_outputs[:2]

             # We share the position biases between the layers - the first layer store them
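Putting the two files together, a hedged end-to-end sketch of a training step with checkpointing enabled (assumes a transformers release containing this PR; the t5-small checkpoint and the toy batch are only illustrative):

from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
config = T5Config.from_pretrained("t5-small", gradient_checkpointing=True, use_cache=False)
model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).train()

inputs = tokenizer("translate English to German: The house is small.", return_tensors="pt")
labels = tokenizer("Das Haus ist klein.", return_tensors="pt").input_ids

loss = model(**inputs, labels=labels).loss  # each T5 block now runs through torch.utils.checkpoint
loss.backward()                             # backward re-runs the block forwards, saving activation memory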