From 329f43dae1e55ba4552a9e4ce96ebe79ff1f439d Mon Sep 17 00:00:00 2001
From: Johannes Klicpera
Date: Sat, 18 Sep 2021 04:50:35 +0200
Subject: [PATCH] Correctly calculate distributed loss average (#269)

Co-authored-by: Abhishek Das
---
 ocpmodels/modules/loss.py          | 24 ++++++++++++++++++++++++
 ocpmodels/trainers/base_trainer.py |  4 +++-
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/ocpmodels/modules/loss.py b/ocpmodels/modules/loss.py
index 8d05b2ab4a..3285b7afad 100644
--- a/ocpmodels/modules/loss.py
+++ b/ocpmodels/modules/loss.py
@@ -1,6 +1,8 @@
 import torch
 from torch import nn
 
+from ocpmodels.common import distutils
+
 
 class L2MAELoss(nn.Module):
     def __init__(self, reduction="mean"):
@@ -14,3 +16,25 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
             return torch.mean(dists)
         elif self.reduction == "sum":
             return torch.sum(dists)
+
+
+class DDPLoss(nn.Module):
+    def __init__(self, loss_fn, reduction="mean"):
+        super().__init__()
+        self.loss_fn = loss_fn
+        self.loss_fn.reduction = "sum"
+        self.reduction = reduction
+        assert reduction in ["mean", "sum"]
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        loss = self.loss_fn(input, target)
+        if self.reduction == "mean":
+            num_samples = input.shape[0]
+            num_samples = distutils.all_reduce(
+                num_samples, device=input.device
+            )
+            # Multiply by world size since gradients are averaged
+            # across DDP replicas
+            return loss * distutils.get_world_size() / num_samples
+        else:
+            return loss
diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py
index ff56cdcee7..fca27cd6b5 100644
--- a/ocpmodels/trainers/base_trainer.py
+++ b/ocpmodels/trainers/base_trainer.py
@@ -40,7 +40,7 @@ from ocpmodels.modules.exponential_moving_average import (
     ExponentialMovingAverage,
 )
 
-from ocpmodels.modules.loss import L2MAELoss
+from ocpmodels.modules.loss import DDPLoss, L2MAELoss
 from ocpmodels.modules.normalizer import Normalizer
 from ocpmodels.modules.scheduler import LRScheduler
 
@@ -366,6 +366,8 @@ def load_loss(self):
             raise NotImplementedError(
                 f"Unknown loss function name: {loss_name}"
             )
+            if distutils.initialized():
+                self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])
 
     def load_optimizer(self):
         optimizer = self.config["optim"].get("optimizer", "AdamW")
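
Since DDP averages gradients across replicas, a per-rank loss that is sum-reduced and then scaled by world_size / total_samples reproduces the gradient of a true mean over the global batch, even when ranks hold different numbers of samples. The following is a minimal, single-process usage sketch of the wrapper added in this patch; the L1 loss, tensor shapes, and random data are illustrative assumptions, not part of the change:

    import torch
    from torch import nn

    from ocpmodels.common import distutils
    from ocpmodels.modules.loss import DDPLoss

    # Wrap the underlying loss only when torch.distributed is initialized,
    # mirroring what load_loss() does in base_trainer.py.
    energy_loss = nn.L1Loss()
    if distutils.initialized():
        energy_loss = DDPLoss(energy_loss)

    pred = torch.randn(8, 1, requires_grad=True)  # hypothetical predictions
    target = torch.randn(8, 1)                    # hypothetical targets
    loss = energy_loss(pred, target)
    loss.backward()  # DDP averages these gradients across replicas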