This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Fix distributed loss (#5381)
* fix distributed loss

* remove extra args
AkshitaB authored Aug 27, 2021
1 parent 6355f07 commit b41cb3e
Showing 3 changed files with 13 additions and 11 deletions.
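
In short: the trainer used to hand `world_size` and `cuda_device` to `training_util.get_metrics` and let it do the cross-worker averaging; after this commit the trainer first sums the accumulated loss and the number of completed batches across workers with `dist_reduce_sum`, and `get_metrics` only divides the totals it receives. As a rough, hedged sketch of what such a helper does (the real one lives in `allennlp.nn.util`; the dtype and device handling below are assumptions, and an NCCL backend would need a CUDA tensor):

import torch
import torch.distributed as dist

def dist_reduce_sum_sketch(value: float) -> float:
    # Outside a distributed run there is nothing to reduce.
    if not (dist.is_available() and dist.is_initialized()):
        return value
    # Sum this worker's value with every other worker's over the default process group.
    tensor = torch.tensor(value, dtype=torch.float64)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item()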
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Multitask models now support `TextFieldTensor` in heads, not just in the backbone.
- Fixed the signature of `ScaledDotProductAttention` to match the other `Attention` classes
- Fixed the way names are applied to Tango `Step` instances.
+ - Fixed a bug in calculating loss in the distributed setting.

### Changed

21 changes: 12 additions & 9 deletions allennlp/training/gradient_descent_trainer.py
@@ -18,6 +18,7 @@
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.models.model import Model
from allennlp.nn.parallel import DdpAccelerator, DdpWrappedModel, TorchDdpAccelerator
+ from allennlp.nn.util import dist_reduce_sum
from allennlp.training.callbacks import ConsoleLoggerCallback
from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback
@@ -544,8 +545,6 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
batch_loss,
batch_reg_loss,
self._batches_in_epoch_completed,
- world_size=self._world_size,
- cuda_device=self.cuda_device,
)

for callback in self._callbacks:
@@ -591,16 +590,19 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
):
metrics = {}
else:
+ train_loss = dist_reduce_sum(train_loss)
+ num_batches = dist_reduce_sum(self._batches_in_epoch_completed)
+ if train_reg_loss is not None:
+ train_reg_loss = dist_reduce_sum(train_reg_loss)

metrics = training_util.get_metrics(
self.model,
train_loss,
train_reg_loss,
batch_loss=None,
batch_reg_loss=None,
- num_batches=self._batches_in_epoch_completed,
+ num_batches=num_batches,
reset=True,
- world_size=self._world_size,
- cuda_device=self.cuda_device,
)

for (worker, memory) in cpu_memory_usage:
@@ -688,8 +690,6 @@ def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]:
val_batch_loss,
val_batch_reg_loss,
batches_this_epoch,
- world_size=self._world_size,
- cuda_device=self.cuda_device,
)

description = training_util.description_from_metrics(val_metrics)
@@ -803,6 +803,11 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
if self._distributed:
dist.barrier()

+ val_loss = dist_reduce_sum(val_loss)
+ num_batches = dist_reduce_sum(num_batches)
+ if val_reg_loss is not None:
+ val_reg_loss = dist_reduce_sum(val_reg_loss)

val_metrics = training_util.get_metrics(
self.model,
val_loss,
@@ -811,8 +816,6 @@
batch_reg_loss=None,
num_batches=num_batches,
reset=True,
- world_size=self._world_size,
- cuda_device=self.cuda_device,
)

# Check validation metric for early stopping
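The effect of the trainer changes above: both the accumulated loss and the number of completed batches are summed over all workers before the division, so the reported loss is a true global average rather than one worker's local average. A tiny worked example with invented numbers (not taken from any real run):

# Hypothetical two-worker run; all numbers are made up for illustration.
worker_losses = [12.0, 20.0]   # summed batch losses on worker 0 and worker 1
worker_batches = [4, 6]        # batches completed on worker 0 and worker 1

total_loss = sum(worker_losses)      # what dist_reduce_sum(train_loss) yields: 32.0
total_batches = sum(worker_batches)  # what dist_reduce_sum(batches_completed) yields: 10

print(total_loss / total_batches)            # 3.2 -- the correct global average
print(worker_losses[0] / worker_batches[0])  # 3.0 -- what worker 0 would report without the reduction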
2 changes: 0 additions & 2 deletions allennlp/training/util.py
@@ -261,8 +261,6 @@ def get_metrics(
batch_reg_loss: Optional[float],
num_batches: int,
reset: bool = False,
- world_size: int = 1,
- cuda_device: Union[int, torch.device] = torch.device("cpu"),
) -> Dict[str, float]:
"""
Gets the metrics but sets `"loss"` to
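With `world_size` and `cuda_device` removed, `get_metrics` performs no cross-worker communication of its own; it simply attaches loss averages computed from whatever totals the caller passes in. A hedged sketch of the simplified shape (everything beyond the signature shown in the diff is an assumption, not AllenNLP's exact code):

from typing import Any, Dict, Optional

def get_metrics_sketch(
    model: Any,
    total_loss: float,
    total_reg_loss: Optional[float],
    batch_loss: Optional[float],
    batch_reg_loss: Optional[float],
    num_batches: int,
    reset: bool = False,
) -> Dict[str, float]:
    # The model reports its own metrics; loss averages come from the
    # already-reduced totals supplied by the trainer.
    metrics = model.get_metrics(reset=reset)
    if batch_loss is not None:
        metrics["batch_loss"] = batch_loss
    metrics["loss"] = float(total_loss / num_batches) if num_batches > 0 else 0.0
    if total_reg_loss is not None:
        if batch_reg_loss is not None:
            metrics["batch_reg_loss"] = batch_reg_loss
        metrics["reg_loss"] = float(total_reg_loss / num_batches) if num_batches > 0 else 0.0
    return metrics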
