diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 400527a225c..e1d1947ecbe 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1466,7 +1466,11 @@ def prediction_step(
         inputs = self._prepare_inputs(inputs)
 
         with torch.no_grad():
-            outputs = model(**inputs)
+            if self.args.fp16 and _use_native_amp:
+                with autocast():
+                    outputs = model(**inputs)
+            else:
+                outputs = model(**inputs)
             if has_labels:
                 loss = outputs[0].mean().detach()
                 logits = outputs[1:]
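
For reference, a minimal standalone sketch of the pattern this hunk applies: run the no-grad evaluation forward pass under torch.cuda.amp.autocast when native-AMP fp16 is requested, and in full precision otherwise. The helper name predict_no_grad and the use_fp16 flag below are illustrative only, not part of the patch.

import torch
from torch.cuda.amp import autocast

def predict_no_grad(model, inputs, use_fp16):
    # Illustrative helper, not part of the patch: mirrors the branching
    # that prediction_step now performs around the forward pass.
    with torch.no_grad():
        if use_fp16:
            # autocast runs eligible CUDA ops in half precision.
            with autocast():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)
    return outputs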