diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 516fff8f91e3..28f0fdd08ed5 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -1512,6 +1512,10 @@ def backward_pass( # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py + assert ( + self.training + ), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode." + with torch.enable_grad(): next_attn_output.requires_grad = True