Fix gradient checkpointing bug in trocr (huggingface#22126)
* Fix gradient checkpointing bug in trocr

* Fix format

* Update src/transformers/models/trocr/modeling_trocr.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2 people authored and novice03 committed Jun 23, 2023
1 parent 8efde69 commit fe14aad
Showing 1 changed file with 7 additions and 6 deletions.
src/transformers/models/trocr/modeling_trocr.py (13 changes: 7 additions & 6 deletions)
@@ -664,6 +664,13 @@ def forward(
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+                )
+                use_cache = False
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -689,12 +696,6 @@ def forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
-                        " False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
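
A minimal, self-contained sketch (not the TrOCR code itself; ToyDecoderLayer and run_decoder_layers are made-up stand-ins) of the pattern the diff context above relies on: when gradient checkpointing is active during training, each decoder layer is run through torch.utils.checkpoint.checkpoint via a create_custom_forward-style wrapper, so activations are recomputed on the backward pass instead of being stored. Checkpointed recomputation does not produce usable past key/value caches, so use_cache is switched off once, before the layer loop, which is what the added block does.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyDecoderLayer(nn.Module):
    # Stands in for a real decoder layer; returns a tuple like decoder layers do.
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, output_attentions=False, use_cache=False):
        return (torch.relu(self.linear(hidden_states)),)


def run_decoder_layers(layers, hidden_states, gradient_checkpointing=True, training=True, use_cache=True):
    if gradient_checkpointing and training:
        if use_cache:
            print("`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...")
        use_cache = False  # disabled once, up front, as in the fixed code

    for layer in layers:
        if gradient_checkpointing and training:

            def create_custom_forward(module):
                # Close over the module and pass the extra flags positionally,
                # mirroring the wrapper shape in the diff context above.
                def custom_forward(*inputs):
                    return module(*inputs, False, use_cache)

                return custom_forward

            layer_outputs = checkpoint(create_custom_forward(layer), hidden_states)
        else:
            layer_outputs = layer(hidden_states, output_attentions=False, use_cache=use_cache)
        hidden_states = layer_outputs[0]
    return hidden_states


layers = nn.ModuleList([ToyDecoderLayer() for _ in range(2)])
x = torch.randn(1, 4, 16, requires_grad=True)
run_decoder_layers(layers, x).sum().backward()  # gradients flow through the recomputed layers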

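A usage-level sketch of where this fix shows up, assuming torch and transformers are installed; the checkpoint name is a real TrOCR checkpoint, but the dummy tensors and single training step are illustrative, not taken from this repository. With this commit, the decoder emits the use_cache warning once (via warning_once) before the layer loop instead of once per layer per forward pass.

import torch
from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.gradient_checkpointing_enable()  # propagates to the TrOCR decoder
model.train()

pixel_values = torch.randn(1, 3, 384, 384)  # dummy image batch for illustration
labels = torch.randint(0, model.decoder.config.vocab_size, (1, 8))  # dummy target ids

loss = model(pixel_values=pixel_values, labels=labels).loss
loss.backward()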