Skip to content

Commit

Permalink
updating gather function with gather_for_metrics in run_wav2vec2_pret…
Browse files Browse the repository at this point in the history
…raining (huggingface#18877)

Co-authored-by: Arun Rajaram <arunrajaram@Aruns-MacBook-Pro.local>
  • Loading branch information
2 people authored and oneraghavan committed Sep 26, 2022
1 parent 72e0395 commit 89e027c
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def prepare_dataset(batch):
# make sure that `num_losses` is summed for distributed training
# and average gradients over losses of all devices
if accelerator.state.num_processes > 1:
num_losses = accelerator.gather(num_losses).sum()
num_losses = accelerator.gather_for_metrics(num_losses).sum()
gradient_multiplier = accelerator.state.num_processes / num_losses
multiply_grads(model.module.parameters(), gradient_multiplier)
else:
Expand Down Expand Up @@ -647,10 +647,10 @@ def prepare_dataset(batch):
outputs.diversity_loss.detach()

if accelerator.state.num_processes > 1:
loss = accelerator.gather(loss).sum()
outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum()
outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum()
percent_masked = accelerator.gather(percent_masked).sum()
loss = accelerator.gather_for_metrics(loss).sum()
outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum()
outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum()
percent_masked = accelerator.gather_for_metrics(percent_masked).sum()

train_logs = {
"loss": (loss * args.gradient_accumulation_steps) / num_losses,
Expand Down Expand Up @@ -713,7 +713,7 @@ def prepare_dataset(batch):

# sum over devices in multi-processing
if accelerator.num_processes > 1:
val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()}
val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}

val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}

Expand Down

0 comments on commit 89e027c

Please sign in to comment.