From 36393edcc56d529a7b8416ccf57fd7dbe3e7636e Mon Sep 17 00:00:00 2001
From: luyug
Date: Sun, 25 Oct 2020 14:44:20 -0400
Subject: [PATCH 1/2] Add mixed precision evaluation

---
 src/transformers/trainer.py       | 6 +++++-
 src/transformers/training_args.py | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 400527a225c..b0e9415972d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1466,7 +1466,11 @@ def prediction_step(
         inputs = self._prepare_inputs(inputs)
 
         with torch.no_grad():
-            outputs = model(**inputs)
+            if self.args.fp16_eval and _use_native_amp:
+                with autocast():
+                    outputs = model(**inputs)
+            else:
+                outputs = model(**inputs)
             if has_labels:
                 loss = outputs[0].mean().detach()
                 logits = outputs[1:]
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index b86a1cbc2b8..fe0ddc35f82 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -268,6 +268,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"},
     )
+    fp16_eval: bool = field(
+        default=False,
+        metadata={"help": "Whether to use 16-bit (mixed) precision in evaluation"},
+    )
     fp16_opt_level: str = field(
         default="O1",
         metadata={

From 4b9042fd96ed8bd6cd313a0c7dcc0a893b51255b Mon Sep 17 00:00:00 2001
From: luyug
Date: Sun, 25 Oct 2020 23:46:36 -0400
Subject: [PATCH 2/2] use original flag

---
 src/transformers/trainer.py       | 2 +-
 src/transformers/training_args.py | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index b0e9415972d..e1d1947ecbe 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1466,7 +1466,7 @@ def prediction_step(
         inputs = self._prepare_inputs(inputs)
 
         with torch.no_grad():
-            if self.args.fp16_eval and _use_native_amp:
+            if self.args.fp16 and _use_native_amp:
                 with autocast():
                     outputs = model(**inputs)
             else:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index fe0ddc35f82..b86a1cbc2b8 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -268,10 +268,6 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"},
    )
-    fp16_eval: bool = field(
-        default=False,
-        metadata={"help": "Whether to use 16-bit (mixed) precision in evaluation"},
-    )
     fp16_opt_level: str = field(
         default="O1",
         metadata={
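
Note (editor's sketch, not part of the patch series): after patch 2/2, the evaluation forward pass is wrapped in native AMP autocast whenever the existing fp16 flag is set, with no separate fp16_eval option. Below is a minimal, self-contained Python illustration of that pattern for readers outside the Trainer context; the linear model, tensor shapes, and use_fp16 flag are placeholders, not Transformers code.

    # Illustrative sketch of fp16 evaluation with native AMP (torch >= 1.6).
    # The model and inputs here are placeholders standing in for the Trainer's.
    import torch
    from torch.cuda.amp import autocast

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_fp16 = device == "cuda"  # stands in for `self.args.fp16 and _use_native_amp`

    model = torch.nn.Linear(768, 2).to(device)
    model.eval()
    batch = torch.randn(8, 768, device=device)

    with torch.no_grad():
        if use_fp16:
            # autocast runs eligible ops in float16 while parameters stay
            # float32, so inference needs no model conversion or GradScaler.
            with autocast():
                logits = model(batch)
        else:
            logits = model(batch)

Because no gradients flow during prediction_step, autocast alone suffices here; GradScaler is only needed for the backward pass during training.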