From cdbc6d52ff07fe41c9585fdae017279ea1a4cf6b Mon Sep 17 00:00:00 2001
From: Ceshine Lee
Date: Wed, 7 Apr 2021 13:33:36 +0800
Subject: [PATCH] Implement gradient checkpointing for T5Stack

---
 src/transformers/models/t5/modeling_t5.py | 60 ++++++++++++++++++-----
 1 file changed, 47 insertions(+), 13 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 2c8463d44edb..6472542a9fa5 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -24,6 +24,7 @@
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -689,6 +690,10 @@ def forward(
 
         outputs = (hidden_states,)
 
+        if present_key_value_state is None:
+            # for compatibility with gradient checkpointing
+            present_key_value_state = torch.tensor(-1.)
+            present_key_value_state.requires_grad = True
         outputs = outputs + (present_key_value_state,) + attention_outputs
         return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
 
@@ -945,23 +950,52 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask=extended_attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_extended_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=layer_head_mask,
-                encoder_layer_head_mask=encoder_layer_head_mask,
-                past_key_value=past_key_value,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and (not self.config.is_decoder):
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return tuple(module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                layer_outputs = checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    extended_attention_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    encoder_layer_head_mask,
+                    past_key_value,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    position_bias=position_bias,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    encoder_layer_head_mask=encoder_layer_head_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
 
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states, present_key_value_state = layer_outputs[:2]
+            if isinstance(present_key_value_state, torch.Tensor) and present_key_value_state.item() == -1:
+                present_key_value_state = None
             # We share the position biases between the layers - the first layer store them
             # layer_outputs = hidden-states, key-value-states (self-attention weights),
             # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
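Note on the approach (not part of the patch): `torch.utils.checkpoint.checkpoint` passes only positional arguments through to the wrapped callable and recomputes the block's activations during the backward pass instead of storing them, which is why the hunk above binds the non-tensor flags (`use_cache`, `output_attentions`) inside a `create_custom_forward` closure. A minimal, self-contained sketch of that pattern is below; `ToyBlock`, its dimensions, and the tensor shapes are illustrative stand-ins, not code from this patch.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Illustrative stand-in for a T5 layer: returns a tuple, as T5Block does."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, output_attentions=False):
        out = torch.relu(self.linear(hidden_states))
        return (out,)


def create_custom_forward(module, output_attentions):
    # checkpoint() forwards only positional arguments, so the non-tensor flag
    # is bound here via a closure, mirroring the wrapper added in the patch.
    def custom_forward(*inputs):
        return tuple(module(*inputs, output_attentions))

    return custom_forward


block = ToyBlock(dim=8)
hidden = torch.randn(2, 4, 8, requires_grad=True)

# Activations inside `block` are recomputed during backward instead of stored.
(out,) = checkpoint(create_custom_forward(block, False), hidden)
out.sum().backward()
print(hidden.grad.shape)  # torch.Size([2, 4, 8])
```

The sentinel added in the T5Block hunk (`torch.tensor(-1.)` with `requires_grad=True`) keeps the layer's output tuple made entirely of tensors when `present_key_value_state` would otherwise be `None` under checkpointing; `T5Stack.forward` then converts the `-1` sentinel back to `None` after unpacking `layer_outputs`.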