Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

fix distributed loss #5380

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions allennlp/training/gradient_descent_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,6 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
batch_reg_loss,
self._batches_in_epoch_completed,
world_size=self._world_size,
cuda_device=self.cuda_device,
)

for callback in self._callbacks:
Expand Down Expand Up @@ -600,7 +599,6 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
num_batches=self._batches_in_epoch_completed,
reset=True,
world_size=self._world_size,
cuda_device=self.cuda_device,
)

for (worker, memory) in cpu_memory_usage:
Expand Down Expand Up @@ -689,7 +687,6 @@ def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]:
val_batch_reg_loss,
batches_this_epoch,
world_size=self._world_size,
cuda_device=self.cuda_device,
)

description = training_util.description_from_metrics(val_metrics)
Expand Down Expand Up @@ -812,7 +809,6 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
num_batches=num_batches,
reset=True,
world_size=self._world_size,
cuda_device=self.cuda_device,
)

# Check validation metric for early stopping
Expand Down
11 changes: 10 additions & 1 deletion allennlp/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,23 @@ def get_metrics(
num_batches: int,
reset: bool = False,
world_size: int = 1,
cuda_device: Union[int, torch.device] = torch.device("cpu"),
) -> Dict[str, float]:
"""
Gets the metrics but sets `"loss"` to
the total loss divided by the `num_batches` so that
the `"loss"` metric is "average loss per batch".
Returns the `"batch_loss"` separately.
"""
if world_size > 1:
total_loss = nn_util.dist_reduce_sum(total_loss)
num_batches = nn_util.dist_reduce_sum(num_batches)
if total_reg_loss is not None:
total_reg_loss = nn_util.dist_reduce_sum(total_reg_loss)
if batch_loss is not None:
batch_loss = nn_util.dist_reduce_sum(batch_loss)
if batch_reg_loss is not None:
batch_reg_loss = nn_util.dist_reduce_sum(batch_reg_loss)

metrics = model.get_metrics(reset=reset)
if batch_loss is not None:
metrics["batch_loss"] = batch_loss
Expand Down