pt: fix loss training when no data available #3571

Merged 4 commits on Mar 22, 2024
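
In brief, this PR gates every energy, force, and virial loss term by the `find_*` flag that ships with the labels, so a system with no reference data for a property contributes nothing to the training loss, and the corresponding logged metrics are shown as NaN through a new `display_if_exist` helper. A minimal, self-contained sketch of the gating pattern (flag and loss values are hypothetical):

```python
import torch

# Each label dict carries a find_<property> flag: 1.0 if reference data
# exists for this batch, 0.0 if it does not.
label = {"find_energy": 0.0}  # hypothetical batch without energy labels

pref_e = 1.0  # configured energy prefactor
find_energy = label.get("find_energy", 0.0)
pref_e = pref_e * find_energy  # zeroed when no energy labels exist

l2_ener_loss = torch.tensor(0.25)  # stand-in for the computed MSE
loss = pref_e * l2_ener_loss  # contributes nothing when find_energy == 0.0
print(loss)  # tensor(0.)
```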
Changes from all commits
59 changes: 40 additions & 19 deletions deepmd/pt/loss/ener.py
```diff
@@ -124,15 +124,21 @@
         # more_loss['test_keys'] = []  # showed when doing dp test
         atom_norm = 1.0 / natoms
         if self.has_e and "energy" in model_pred and "energy" in label:
+            find_energy = label.get("find_energy", 0.0)
+            pref_e = pref_e * find_energy
             if not self.use_l1_all:
                 l2_ener_loss = torch.mean(
                     torch.square(model_pred["energy"] - label["energy"])
                 )
                 if not self.inference:
-                    more_loss["l2_ener_loss"] = l2_ener_loss.detach()
+                    more_loss["l2_ener_loss"] = self.display_if_exist(
+                        l2_ener_loss.detach(), find_energy
+                    )
                 loss += atom_norm * (pref_e * l2_ener_loss)
                 rmse_e = l2_ener_loss.sqrt() * atom_norm
-                more_loss["rmse_e"] = rmse_e.detach()
+                more_loss["rmse_e"] = self.display_if_exist(
+                    rmse_e.detach(), find_energy
+                )
                 # more_loss['log_keys'].append('rmse_e')
             else:  # use l1 and for all atoms
                 l1_ener_loss = F.l1_loss(
@@ -141,24 +147,31 @@
                     reduction="sum",
                 )
                 loss += pref_e * l1_ener_loss
-                more_loss["mae_e"] = F.l1_loss(
-                    model_pred["energy"].reshape(-1),
-                    label["energy"].reshape(-1),
-                    reduction="mean",
-                ).detach()
+                more_loss["mae_e"] = self.display_if_exist(
+                    F.l1_loss(
+                        model_pred["energy"].reshape(-1),
+                        label["energy"].reshape(-1),
+                        reduction="mean",
+                    ).detach(),
+                    find_energy,
+                )
                 # more_loss['log_keys'].append('rmse_e')
             if mae:
                 mae_e = (
                     torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
                     * atom_norm
                 )
-                more_loss["mae_e"] = mae_e.detach()
+                more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
                 mae_e_all = torch.mean(
                     torch.abs(model_pred["energy"] - label["energy"])
                 )
-                more_loss["mae_e_all"] = mae_e_all.detach()
+                more_loss["mae_e_all"] = self.display_if_exist(
+                    mae_e_all.detach(), find_energy
+                )

         if self.has_f and "force" in model_pred and "force" in label:
+            find_force = label.get("find_force", 0.0)
+            pref_f = pref_f * find_force
             if "force_target_mask" in model_pred:
                 force_target_mask = model_pred["force_target_mask"]
             else:
@@ -174,40 +187,48 @@
                 diff_f = label["force"] - model_pred["force"]
                 l2_force_loss = torch.mean(torch.square(diff_f))
                 if not self.inference:
-                    more_loss["l2_force_loss"] = l2_force_loss.detach()
+                    more_loss["l2_force_loss"] = self.display_if_exist(
+                        l2_force_loss.detach(), find_force
+                    )
                 loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                 rmse_f = l2_force_loss.sqrt()
-                more_loss["rmse_f"] = rmse_f.detach()
+                more_loss["rmse_f"] = self.display_if_exist(rmse_f.detach(), find_force)
             else:
                 l1_force_loss = F.l1_loss(
                     label["force"], model_pred["force"], reduction="none"
                 )
                 if force_target_mask is not None:
                     l1_force_loss *= force_target_mask
                     force_cnt = force_target_mask.squeeze(-1).sum(-1)
-                    more_loss["mae_f"] = (
-                        l1_force_loss.mean(-1).sum(-1) / force_cnt
-                    ).mean()
+                    more_loss["mae_f"] = self.display_if_exist(
+                        (l1_force_loss.mean(-1).sum(-1) / force_cnt).mean(), find_force
+                    )
                     l1_force_loss = (l1_force_loss.sum(-1).sum(-1) / force_cnt).sum()
                 else:
-                    more_loss["mae_f"] = l1_force_loss.mean().detach()
+                    more_loss["mae_f"] = self.display_if_exist(
+                        l1_force_loss.mean().detach(), find_force
+                    )
                     l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
                 loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
             if mae:
                 mae_f = torch.mean(torch.abs(diff_f))
-                more_loss["mae_f"] = mae_f.detach()
+                more_loss["mae_f"] = self.display_if_exist(mae_f.detach(), find_force)

         if self.has_v and "virial" in model_pred and "virial" in label:
+            find_virial = label.get("find_virial", 0.0)
+            pref_v = pref_v * find_virial
             diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
             l2_virial_loss = torch.mean(torch.square(diff_v))
             if not self.inference:
-                more_loss["l2_virial_loss"] = l2_virial_loss.detach()
+                more_loss["l2_virial_loss"] = self.display_if_exist(
+                    l2_virial_loss.detach(), find_virial
+                )
             loss += atom_norm * (pref_v * l2_virial_loss)
             rmse_v = l2_virial_loss.sqrt() * atom_norm
-            more_loss["rmse_v"] = rmse_v.detach()
+            more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
             if mae:
                 mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
-                more_loss["mae_v"] = mae_v.detach()
+                more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)

         if not self.inference:
             more_loss["rmse"] = torch.sqrt(loss.detach())
         return model_pred, loss, more_loss
```
65 changes: 48 additions & 17 deletions deepmd/pt/loss/ener_spin.py
```diff
@@ -98,15 +98,21 @@
         # more_loss['test_keys'] = []  # showed when doing dp test
         atom_norm = 1.0 / natoms
         if self.has_e and "energy" in model_pred and "energy" in label:
+            find_energy = label.get("find_energy", 0.0)
+            pref_e = pref_e * find_energy
             if not self.use_l1_all:
                 l2_ener_loss = torch.mean(
                     torch.square(model_pred["energy"] - label["energy"])
                 )
                 if not self.inference:
-                    more_loss["l2_ener_loss"] = l2_ener_loss.detach()
+                    more_loss["l2_ener_loss"] = self.display_if_exist(
+                        l2_ener_loss.detach(), find_energy
+                    )
                 loss += atom_norm * (pref_e * l2_ener_loss)
                 rmse_e = l2_ener_loss.sqrt() * atom_norm
-                more_loss["rmse_e"] = rmse_e.detach()
+                more_loss["rmse_e"] = self.display_if_exist(
+                    rmse_e.detach(), find_energy
+                )
                 # more_loss['log_keys'].append('rmse_e')
             else:  # use l1 and for all atoms
                 l1_ener_loss = F.l1_loss(
@@ -115,44 +121,61 @@
                     reduction="sum",
                 )
                 loss += pref_e * l1_ener_loss
-                more_loss["mae_e"] = F.l1_loss(
-                    model_pred["energy"].reshape(-1),
-                    label["energy"].reshape(-1),
-                    reduction="mean",
-                ).detach()
+                more_loss["mae_e"] = self.display_if_exist(
+                    F.l1_loss(
+                        model_pred["energy"].reshape(-1),
+                        label["energy"].reshape(-1),
+                        reduction="mean",
+                    ).detach(),
+                    find_energy,
+                )
                 # more_loss['log_keys'].append('rmse_e')
             if mae:
                 mae_e = (
                     torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
                     * atom_norm
                 )
-                more_loss["mae_e"] = mae_e.detach()
+                more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
                 mae_e_all = torch.mean(
                     torch.abs(model_pred["energy"] - label["energy"])
                 )
-                more_loss["mae_e_all"] = mae_e_all.detach()
+                more_loss["mae_e_all"] = self.display_if_exist(
+                    mae_e_all.detach(), find_energy
+                )

         if self.has_fr and "force" in model_pred and "force" in label:
+            find_force_r = label.get("find_force", 0.0)
+            pref_fr = pref_fr * find_force_r
             if not self.use_l1_all:
                 diff_fr = label["force"] - model_pred["force"]
                 l2_force_real_loss = torch.mean(torch.square(diff_fr))
                 if not self.inference:
-                    more_loss["l2_force_r_loss"] = l2_force_real_loss.detach()
+                    more_loss["l2_force_r_loss"] = self.display_if_exist(
+                        l2_force_real_loss.detach(), find_force_r
+                    )
                 loss += (pref_fr * l2_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                 rmse_fr = l2_force_real_loss.sqrt()
-                more_loss["rmse_fr"] = rmse_fr.detach()
+                more_loss["rmse_fr"] = self.display_if_exist(
+                    rmse_fr.detach(), find_force_r
+                )
                 if mae:
                     mae_fr = torch.mean(torch.abs(diff_fr))
-                    more_loss["mae_fr"] = mae_fr.detach()
+                    more_loss["mae_fr"] = self.display_if_exist(
+                        mae_fr.detach(), find_force_r
+                    )
             else:
                 l1_force_real_loss = F.l1_loss(
                     label["force"], model_pred["force"], reduction="none"
                 )
-                more_loss["mae_fr"] = l1_force_real_loss.mean().detach()
+                more_loss["mae_fr"] = self.display_if_exist(
+                    l1_force_real_loss.mean().detach(), find_force_r
+                )
                 l1_force_real_loss = l1_force_real_loss.sum(-1).mean(-1).sum()
                 loss += (pref_fr * l1_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)

         if self.has_fm and "force_mag" in model_pred and "force_mag" in label:
+            find_force_m = label.get("find_force_mag", 0.0)
+            pref_fm = pref_fm * find_force_m
             nframes = model_pred["force_mag"].shape[0]
             atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3])
             label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3)
@@ -163,18 +186,26 @@
                 diff_fm = label_force_mag - model_pred_force_mag
                 l2_force_mag_loss = torch.mean(torch.square(diff_fm))
                 if not self.inference:
-                    more_loss["l2_force_m_loss"] = l2_force_mag_loss.detach()
+                    more_loss["l2_force_m_loss"] = self.display_if_exist(
+                        l2_force_mag_loss.detach(), find_force_m
+                    )
                 loss += (pref_fm * l2_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                 rmse_fm = l2_force_mag_loss.sqrt()
-                more_loss["rmse_fm"] = rmse_fm.detach()
+                more_loss["rmse_fm"] = self.display_if_exist(
+                    rmse_fm.detach(), find_force_m
+                )
                 if mae:
                     mae_fm = torch.mean(torch.abs(diff_fm))
-                    more_loss["mae_fm"] = mae_fm.detach()
+                    more_loss["mae_fm"] = self.display_if_exist(
+                        mae_fm.detach(), find_force_m
+                    )
             else:
                 l1_force_mag_loss = F.l1_loss(
                     label_force_mag, model_pred_force_mag, reduction="none"
                 )
-                more_loss["mae_fm"] = l1_force_mag_loss.mean().detach()
+                more_loss["mae_fm"] = self.display_if_exist(
+                    l1_force_mag_loss.mean().detach(), find_force_m
+                )
                 l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum()
                 loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
```
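
The spin variant gates the real and magnetic force terms independently: `find_force` for the atomic forces and `find_force_mag` for the magnetic ones, so a dataset may provide one without the other. For illustration (flag values are hypothetical):

```python
label = {"find_force": 1.0, "find_force_mag": 0.0}  # forces present, no magnetic data

pref_fr = 1.0 * label.get("find_force", 0.0)      # 1.0: real-force term trains
pref_fm = 1.0 * label.get("find_force_mag", 0.0)  # 0.0: magnetic term is inert
```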
13 changes: 13 additions & 0 deletions deepmd/pt/loss/loss.py
```diff
@@ -28,3 +28,16 @@ def forward(self, input_dict, model, label, natoms, learning_rate):
     def label_requirement(self) -> List[DataRequirementItem]:
         """Return data label requirements needed for this loss calculation."""
         pass
+
+    @staticmethod
+    def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor:
+        """Display NaN if labeled property is not found.
+
+        Parameters
+        ----------
+        loss : torch.Tensor
+            the loss tensor
+        find_property : float
+            whether the property is found
+        """
+        return loss if bool(find_property) else torch.nan
```
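
The helper is the whole display mechanism: when the `find_*` flag is zero, the metric is replaced by NaN so the training log states explicitly that no reference data was available. A quick usage sketch with made-up numbers, reproducing the method as a free function for demonstration:

```python
import torch

def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor:
    """Return the metric unchanged if the label exists, otherwise NaN."""
    return loss if bool(find_property) else torch.nan

rmse_e = torch.tensor(0.0123)
print(display_if_exist(rmse_e, 1.0))  # tensor(0.0123): data available
print(display_if_exist(rmse_e, 0.0))  # nan: no reference data for this batch
```

Note that the flag is multiplied into the loss prefactor separately; `display_if_exist` only affects what is logged, not the gradient.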
24 changes: 17 additions & 7 deletions deepmd/pt/loss/tensor.py
```diff
@@ -95,6 +95,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
             and self.tensor_name in model_pred
             and "atomic_" + self.label_name in label
         ):
+            find_local = label.get("find_" + "atomic_" + self.label_name, 0.0)
+            local_weight = self.local_weight * find_local
             local_tensor_pred = model_pred[self.tensor_name].reshape(
                 [-1, natoms, self.tensor_size]
             )
@@ -108,15 +110,21 @@
                 diff = diff[model_pred["mask"].reshape([-1]).bool()]
             l2_local_loss = torch.mean(torch.square(diff))
             if not self.inference:
-                more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach()
-            loss += self.local_weight * l2_local_loss
+                more_loss[f"l2_local_{self.tensor_name}_loss"] = self.display_if_exist(
+                    l2_local_loss.detach(), find_local
+                )
+            loss += local_weight * l2_local_loss
             rmse_local = l2_local_loss.sqrt()
-            more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach()
+            more_loss[f"rmse_local_{self.tensor_name}"] = self.display_if_exist(
+                rmse_local.detach(), find_local
+            )
         if (
             self.has_global_weight
             and "global_" + self.tensor_name in model_pred
             and self.label_name in label
         ):
+            find_global = label.get("find_" + self.label_name, 0.0)
+            global_weight = self.global_weight * find_global
             global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
                 [-1, self.tensor_size]
             )
@@ -132,12 +140,14 @@
                 atom_num = natoms
             l2_global_loss = torch.mean(torch.square(diff))
             if not self.inference:
-                more_loss[f"l2_global_{self.tensor_name}_loss"] = (
-                    l2_global_loss.detach()
+                more_loss[f"l2_global_{self.tensor_name}_loss"] = self.display_if_exist(
+                    l2_global_loss.detach(), find_global
                 )
-            loss += self.global_weight * l2_global_loss
+            loss += global_weight * l2_global_loss
             rmse_global = l2_global_loss.sqrt() / atom_num
-            more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
+            more_loss[f"rmse_global_{self.tensor_name}"] = self.display_if_exist(
+                rmse_global.detach(), find_global
+            )
         return model_pred, loss, more_loss

     @property
```
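
The tensor loss uses two distinct flags: the local (per-atom) term is gated by `find_atomic_<label_name>` and the global term by `find_<label_name>`, so atomic and global tensor labels can be supplied independently. A small illustration, assuming a hypothetical `label_name = "polarizability"`:

```python
label_name = "polarizability"  # hypothetical label name
label = {"find_atomic_" + label_name: 1.0}  # only per-atom labels present

find_local = label.get("find_" + "atomic_" + label_name, 0.0)  # 1.0
find_global = label.get("find_" + label_name, 0.0)  # 0.0
# The local term trains as usual; the global term is zeroed and its
# metrics are displayed as NaN.
```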
3 changes: 2 additions & 1 deletion deepmd/pt/train/training.py
```diff
@@ -1017,7 +1017,7 @@ def get_data(self, is_train=True, task_key="Default"):
             if item_key in input_keys:
                 input_dict[item_key] = batch_data[item_key]
             else:
-                if item_key not in ["sid", "fid"] and "find_" not in item_key:
+                if item_key not in ["sid", "fid"]:
                     label_dict[item_key] = batch_data[item_key]
         log_dict = {}
         if "fid" in batch_data:
@@ -1052,6 +1052,7 @@ def print_header(self, fout, train_results, valid_results):
             for k in sorted(train_results[model_key].keys()):
                 print_str += prop_fmt % (k + f"_trn_{model_key}")
         print_str += " %8s\n" % "lr"
+        print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n"
         fout.write(print_str)
         fout.flush()
```
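
For the losses to see these flags at all, `get_data` must stop filtering `find_*` keys out of the label dict, and the learning-curve header now warns that missing data prints as NaN. A standalone sketch of the revised filtering (the keys and `input_keys` value are illustrative):

```python
batch_data = {
    "energy": [0.0],     # reference labels
    "find_energy": 1.0,  # presence flag, now forwarded to the loss
    "sid": 3,            # bookkeeping keys, still excluded
    "fid": 7,
}
input_keys = []  # assume none of these keys are model inputs here
label_dict = {
    k: v
    for k, v in batch_data.items()
    if k not in input_keys and k not in ["sid", "fid"]
}
assert "find_energy" in label_dict  # previously dropped by the old filter
```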
2 changes: 2 additions & 0 deletions source/tests/pt/model/test_model.py
```diff
@@ -352,7 +352,9 @@ def test_consistency(self):
         }
         label = {
             "energy": batch["energy"].to(env.DEVICE),
+            "find_energy": 1.0,
             "force": batch["force"].to(env.DEVICE),
+            "find_force": 1.0,
         }
         cur_lr = my_lr.value(self.wanted_step)
         model_predict, loss, _ = my_loss(
```