Skip to content

Commit

Permalink
Update no trainer scripts for language modeling and image classification examples (#18443)
Browse files Browse the repository at this point in the history

* Update no_trainer script for image-classification

* Update no_trainer scripts for language-modeling examples

* Remove unused variable

* Removing truncation from losses array for language modeling examples
  • Loading branch information
nandwalritik authored Aug 3, 2022
1 parent 10e1ec9 commit 3db4378
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -508,19 +508,11 @@ def collate_fn(examples):
break

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
samples_seen += references.shape[0]
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
3 changes: 1 addition & 2 deletions examples/pytorch/language-modeling/run_clm_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,9 @@ def group_texts(examples):
outputs = model(**batch)

loss = outputs.loss
losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
try:
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
Expand Down
3 changes: 1 addition & 2 deletions examples/pytorch/language-modeling/run_mlm_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,10 +642,9 @@ def group_texts(examples):
outputs = model(**batch)

loss = outputs.loss
losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
try:
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
Expand Down

0 comments on commit 3db4378

Please sign in to comment.