Implement gradient checkpointing for T5Stack
ceshine committed Apr 7, 2021
1 parent fd338ab commit cdbc6d5
Showing 1 changed file with 47 additions and 13 deletions.
src/transformers/models/t5/modeling_t5.py
@@ -24,6 +24,7 @@
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -689,6 +690,10 @@ def forward(

         outputs = (hidden_states,)

+        if present_key_value_state is None:
+            # for compatibility with gradient checkpointing
+            present_key_value_state = torch.tensor(-1.)
+            present_key_value_state.requires_grad = True
         outputs = outputs + (present_key_value_state,) + attention_outputs
         return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)

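Note on the hunk above: one plausible reading of the "for compatibility with gradient checkpointing" comment is that torch.utils.checkpoint re-runs the wrapped function during backward and works with tensor outputs, so a block that has no key/value cache to return cannot simply return None; a dummy tensor stands in and is mapped back to None by the caller (see the isinstance check in the next hunk). A minimal standalone sketch of that pattern, not part of this commit (the toy_block function is purely illustrative):

# Standalone sketch (not from this commit) of the sentinel-tensor pattern used above.
import torch
from torch.utils.checkpoint import checkpoint

def toy_block(x):
    # stand-in for a T5 block in training mode: nothing to cache, so return a sentinel
    sentinel = torch.tensor(-1.)
    sentinel.requires_grad = True
    return x * 2, sentinel

x = torch.randn(4, requires_grad=True)
out, present = checkpoint(toy_block, x)
if isinstance(present, torch.Tensor) and present.item() == -1:
    present = None  # restore the original "no cache" meaning after the checkpointed call
out.sum().backward()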
@@ -945,23 +950,52 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask=extended_attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_extended_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=layer_head_mask,
-                encoder_layer_head_mask=encoder_layer_head_mask,
-                past_key_value=past_key_value,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and (not self.config.is_decoder):
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return tuple(module(*inputs, use_cache, output_attentions))
+                    return custom_forward
+
+                layer_outputs = checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    extended_attention_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    encoder_layer_head_mask,
+                    past_key_value,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    position_bias=position_bias,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    encoder_layer_head_mask=encoder_layer_head_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states, present_key_value_state = layer_outputs[:2]

+            if isinstance(present_key_value_state, torch.Tensor) and present_key_value_state.item() == -1:
+                present_key_value_state = None
+
             # We share the position biases between the layers - the first layer store them
             # layer_outputs = hidden-states, key-value-states (self-attention weights),
             # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
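
For reference, a minimal usage sketch (not part of this commit) of how gradient checkpointing would be switched on after this change. It assumes a transformers build that includes this patch; the "t5-small" checkpoint and the dummy inputs are purely illustrative.

# Minimal usage sketch (not part of this commit): enable the config flag that
# T5Stack.forward checks via getattr(self.config, "gradient_checkpointing", False).
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.config.gradient_checkpointing = True  # checkpoint the (encoder) blocks during training
model.config.use_cache = False              # caching is incompatible with checkpointing
model.train()

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
labels = tokenizer("Hallo", return_tensors="pt").input_ids

# Activations inside the checkpointed blocks are recomputed in the backward
# pass instead of being stored, trading extra compute for lower memory use.
loss = model(**inputs, labels=labels).loss
loss.backward()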
