Skip to content

Commit

Permalink
Add an error message that fires when Reformer is not in training mode…
Browse files Browse the repository at this point in the history
…, but one runs .backward() (#11117)
  • Loading branch information
forest1988 authored and Rocketknight1 committed Apr 21, 2021
1 parent f3d3e7d commit f6bbd3a
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/reformer/modeling_reformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,10 @@ def backward_pass(
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py

assert (
self.training
), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."

with torch.enable_grad():
next_attn_output.requires_grad = True

Expand Down

0 comments on commit f6bbd3a

Please sign in to comment.