Skip to content

Commit

Permalink
Support for using tag weights with L2MAE force loss (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhshkdz authored Jan 19, 2023
1 parent 71a56ac commit 255c9ef
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions ocpmodels/trainers/forces_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,21 +489,28 @@ def _compute_loss(self, out, batch_list):
weight[batch_tags == 1] = tag_specific_weights[1]
weight[batch_tags == 2] = tag_specific_weights[2]

loss_force_list = torch.abs(out["forces"] - force_target)
train_loss_force_unnormalized = torch.sum(
loss_force_list * weight.view(-1, 1)
)
train_loss_force_normalizer = 3.0 * weight.sum()

# add up normalizer to obtain global normalizer
distutils.all_reduce(train_loss_force_normalizer)
if self.config["optim"].get("loss_force", "l2mae") == "l2mae":
dists = torch.norm(
out["forces"] - force_target, p=2, dim=-1
)
weighted_dists_sum = (dists * weight).sum()

# perform loss normalization before backprop
train_loss_force_normalized = train_loss_force_unnormalized * (
distutils.get_world_size() / train_loss_force_normalizer
)
loss.append(train_loss_force_normalized)
num_samples = out["forces"].shape[0]
num_samples = distutils.all_reduce(
num_samples, device=self.device
)
weighted_dists_sum = (
weighted_dists_sum
* distutils.get_world_size()
/ num_samples
)

force_mult = self.config["optim"].get(
"force_coefficient", 30
)
loss.append(force_mult * weighted_dists_sum)
else:
raise NotImplementedError
else:
# Force coefficient = 30 has been working well for us.
force_mult = self.config["optim"].get("force_coefficient", 30)
Expand Down

0 comments on commit 255c9ef

Please sign in to comment.