From 93c1f391d33d67733add30813def53134db49594 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 26 Aug 2021 19:35:33 -0700
Subject: [PATCH 1/2] fix distributed loss

---
 allennlp/training/gradient_descent_trainer.py | 4 ----
 allennlp/training/util.py                     | 6 +++++-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/allennlp/training/gradient_descent_trainer.py b/allennlp/training/gradient_descent_trainer.py
index ca35e20aaae..f84c513d260 100644
--- a/allennlp/training/gradient_descent_trainer.py
+++ b/allennlp/training/gradient_descent_trainer.py
@@ -545,7 +545,6 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                 batch_reg_loss,
                 self._batches_in_epoch_completed,
                 world_size=self._world_size,
-                cuda_device=self.cuda_device,
             )
 
             for callback in self._callbacks:
@@ -600,7 +599,6 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
             num_batches=self._batches_in_epoch_completed,
             reset=True,
             world_size=self._world_size,
-            cuda_device=self.cuda_device,
         )
 
         for (worker, memory) in cpu_memory_usage:
@@ -689,7 +687,6 @@ def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]:
                 val_batch_reg_loss,
                 batches_this_epoch,
                 world_size=self._world_size,
-                cuda_device=self.cuda_device,
             )
 
             description = training_util.description_from_metrics(val_metrics)
@@ -812,7 +809,6 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
                 num_batches=num_batches,
                 reset=True,
                 world_size=self._world_size,
-                cuda_device=self.cuda_device,
             )
 
             # Check validation metric for early stopping
diff --git a/allennlp/training/util.py b/allennlp/training/util.py
index 6b19e0a3fcf..4d0cec62788 100644
--- a/allennlp/training/util.py
+++ b/allennlp/training/util.py
@@ -262,7 +262,6 @@ def get_metrics(
     num_batches: int,
     reset: bool = False,
     world_size: int = 1,
-    cuda_device: Union[int, torch.device] = torch.device("cpu"),
 ) -> Dict[str, float]:
     """
     Gets the metrics but sets `"loss"` to
@@ -278,6 +277,11 @@ def get_metrics(
         if batch_reg_loss is not None:
             metrics["batch_reg_loss"] = batch_reg_loss
         metrics["reg_loss"] = float(total_reg_loss / num_batches) if num_batches > 0 else 0.0
+    if world_size > 1:
+        for key in {"loss", "reg_loss", "batch_loss", "batch_reg_loss"}:
+            if key not in metrics:
+                continue
+            metrics[key] = nn_util.dist_reduce_sum(metrics[key]) / world_size
 
     return metrics
 
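One caveat with this first patch: it all-reduces each already-averaged metric and then divides by `world_size`, i.e. it takes an average of per-worker averages. That is only exact when every worker processed the same number of batches. Below is a minimal standalone sketch of the discrepancy; the two-worker numbers are made up for illustration and do not come from the patch:

    # Hypothetical two-worker run: worker 0 saw 3 batches (total loss 6.0),
    # worker 1 saw 1 batch (total loss 4.0).
    worker_totals = [6.0, 4.0]
    worker_batches = [3, 1]

    # Patch 1's scheme: average locally, sum the averages across workers,
    # then divide by world_size.
    avg_of_avgs = sum(t / n for t, n in zip(worker_totals, worker_batches)) / 2
    # (6.0 / 3 + 4.0 / 1) / 2 = (2.0 + 4.0) / 2 = 3.0

    # True global average: sum the raw totals across workers, divide once.
    true_avg = sum(worker_totals) / sum(worker_batches)
    # 10.0 / 4 = 2.5

    print(avg_of_avgs, true_avg)  # 3.0 vs. 2.5

The second patch below fixes this by reducing the raw sums, including `num_batches`, before any division.
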
""" + if world_size > 1: + total_loss = nn_util.dist_reduce_sum(total_loss) + num_batches = nn_util.dist_reduce_sum(num_batches) + if total_reg_loss is not None: + total_reg_loss = nn_util.dist_reduce_sum(total_reg_loss) + if batch_loss is not None: + batch_loss = nn_util.dist_reduce_sum(batch_loss) + if batch_reg_loss is not None: + batch_reg_loss = nn_util.dist_reduce_sum(batch_reg_loss) + metrics = model.get_metrics(reset=reset) if batch_loss is not None: metrics["batch_loss"] = batch_loss @@ -277,11 +287,6 @@ def get_metrics( if batch_reg_loss is not None: metrics["batch_reg_loss"] = batch_reg_loss metrics["reg_loss"] = float(total_reg_loss / num_batches) if num_batches > 0 else 0.0 - if world_size > 1: - for key in {"loss", "reg_less", "batch_loss", "batch_reg_loss"}: - if key not in metrics: - continue - metrics[key] = nn_util.dist_reduce_sum(metrics[key]) / world_size return metrics