add torch.no_grad when in eval mode (huggingface#17020)
* add torch.no_grad when in eval mode

* make style quality
JunnYu authored May 2, 2022 · 1 parent 9586e22 · commit bdd690a
Showing 4 changed files with 10 additions and 4 deletions.
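All four files receive the same fix. model.eval() only switches module behavior (dropout, batch-norm statistics); it does not stop autograd, so every evaluation forward pass still records a computation graph and keeps intermediate activations alive for a backward call that never happens. Wrapping the forward pass in torch.no_grad() turns that recording off, which lowers memory use and speeds up evaluation. A minimal self-contained sketch of the pattern the commit applies, with a hypothetical stand-in model and dataloader in place of the scripts' real ones:

import torch
from torch import nn

# Hypothetical stand-ins; the real scripts build a Transformers model
# and an accelerate-prepared DataLoader.
model = nn.Linear(4, 2)
eval_dataloader = [{"x": torch.randn(3, 4)} for _ in range(2)]

model.eval()  # switches off dropout / batch-norm updates only
for batch in eval_dataloader:
    with torch.no_grad():  # no graph is recorded inside this block
        logits = model(batch["x"])
    predictions = logits.argmax(dim=-1)  # fine outside the block; logits carry no grad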
@@ -469,7 +469,8 @@ def collate_fn(examples):
     model.eval()
     samples_seen = 0
     for step, batch in enumerate(eval_dataloader):
-        outputs = model(**batch)
+        with torch.no_grad():
+            outputs = model(**batch)
         predictions = outputs.logits.argmax(dim=-1)
         predictions, references = accelerator.gather((predictions, batch["labels"]))
         # If we are in a multiprocess environment, the last batch has duplicates

@@ -579,7 +579,8 @@ def preprocess_val(example_batch):
     model.eval()
     samples_seen = 0
     for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
-        outputs = model(**batch)
+        with torch.no_grad():
+            outputs = model(**batch)
 
         upsampled_logits = torch.nn.functional.interpolate(
             outputs.logits, size=batch["labels"].shape[-2:], mode="bilinear", align_corners=False

examples/pytorch/text-classification/run_glue_no_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -22,6 +22,7 @@
 from pathlib import Path
 
 import datasets
+import torch
 from datasets import load_dataset, load_metric
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
@@ -514,7 +515,8 @@ def preprocess_function(examples):
     model.eval()
     samples_seen = 0
     for step, batch in enumerate(eval_dataloader):
-        outputs = model(**batch)
+        with torch.no_grad():
+            outputs = model(**batch)
         predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
         predictions, references = accelerator.gather((predictions, batch["labels"]))
         # If we are in a multiprocess environment, the last batch has duplicates

@@ -28,6 +28,7 @@
 from typing import Optional, List
 
 import datasets
+import torch
 from datasets import load_dataset
 
 import transformers
@@ -871,7 +872,8 @@ def tokenize_function(examples):
 
     model.eval()
     for step, batch in enumerate(eval_dataloader):
-        outputs = model(**batch)
+        with torch.no_grad():
+            outputs = model(**batch)
         predictions = outputs.logits.argmax(dim=-1)
         metric.add_batch(
             predictions=accelerator.gather(predictions),
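
A quick check (not part of the commit) of why eval mode alone is not enough: outputs produced under torch.no_grad() carry requires_grad=False, while after model.eval() alone gradient tracking stays on.

import torch
from torch import nn

model = nn.Linear(4, 2)  # hypothetical stand-in model
x = torch.randn(1, 4)

model.eval()
print(model(x).requires_grad)       # True: eval() alone still tracks gradients

with torch.no_grad():
    print(model(x).requires_grad)   # False: no graph is built under no_grad()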
