diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py index f2658db9e68c..b1c1848aa313 100644 --- a/examples/pytorch/text-classification/run_glue_no_trainer.py +++ b/examples/pytorch/text-classification/run_glue_no_trainer.py @@ -404,7 +404,7 @@ def preprocess_function(examples): model.eval() for step, batch in enumerate(eval_dataloader): outputs = model(**batch) - predictions = outputs.logits.argmax(dim=-1) + predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() metric.add_batch( predictions=accelerator.gather(predictions), references=accelerator.gather(batch["labels"]),