Skip to content

Commit

Permalink
Add an error message that fires when Reformer is not in training mode…
Browse files Browse the repository at this point in the history
…, but one runs .backward() (huggingface#11117)
  • Loading branch information
forest1988 authored and Iwontbecreative committed Jul 15, 2021
1 parent 3167835 commit 8e806a6
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/reformer/modeling_reformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,10 @@ def backward_pass(
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py

assert (
self.training
), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."

with torch.enable_grad():
next_attn_output.requires_grad = True

Expand Down

0 comments on commit 8e806a6

Please sign in to comment.