Make gradient checkpointing work with the decoder
ceshine committed Apr 24, 2021
1 parent 81254e6 commit acaeee6
Showing 1 changed file with 3 additions and 2 deletions.
src/transformers/models/t5/modeling_t5.py (3 additions, 2 deletions)
@@ -486,6 +486,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                 position_bias = torch.zeros(
                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                 )
+                if self.training:
+                    position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)

@@ -955,7 +957,6 @@ def forward(
             if (
                 getattr(self.config, "gradient_checkpointing", False)
                 and self.training
-                and (not self.config.is_decoder)
             ):
                 if use_cache:
                     logger.warn(
@@ -980,7 +981,7 @@ def custom_forward(*inputs):
                     encoder_decoder_position_bias,
                     layer_head_mask,
                     cross_attn_layer_head_mask,
-                    past_key_value,
+                    None  # past_key_value is always None with gradient checkpointing
                 )
             else:
                 layer_outputs = layer_module(

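For reference, here is a minimal usage sketch (not part of the commit) showing how the patched path is exercised. It assumes a transformers version from this era, where gradient checkpointing is toggled through the `gradient_checkpointing` config attribute that the changed code reads via getattr(); the checkpoint name and the inputs are placeholders.

    # Minimal sketch (not from the commit): run T5 with gradient checkpointing
    # active in both the encoder and, after this change, the decoder stack.
    from transformers import T5ForConditionalGeneration, T5TokenizerFast

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    model.config.gradient_checkpointing = True  # read by the getattr() check in T5Stack.forward
    model.config.use_cache = False              # key/value caching is incompatible with checkpointing
    model.train()                               # the new requires_grad branch only runs in training mode

    tokenizer = T5TokenizerFast.from_pretrained("t5-small")
    inputs = tokenizer("translate English to German: Hello world", return_tensors="pt")
    labels = tokenizer("Hallo Welt", return_tensors="pt").input_ids

    loss = model(**inputs, labels=labels).loss
    loss.backward()  # decoder blocks are recomputed here instead of storing all activations

Passing None for past_key_value mirrors the in-line comment that the cache is never used under checkpointing, and setting requires_grad on the freshly created position_bias appears intended to keep the checkpointed blocks attached to the autograd graph when they are re-run during backward.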