From 9f6e269cdc453fc771f9587284d93c82f07761ec Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 24 Oct 2024 15:19:06 +0800 Subject: [PATCH 1/2] support loss plugin for external package --- deepmd/pt/loss/loss.py | 5 ++++- deepmd/pt/train/training.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py index 1a091e074e..691646b279 100644 --- a/deepmd/pt/loss/loss.py +++ b/deepmd/pt/loss/loss.py @@ -9,9 +9,12 @@ from deepmd.utils.data import ( DataRequirementItem, ) +from deepmd.utils.plugin import ( + make_plugin_registry, +) -class TaskLoss(torch.nn.Module, ABC): +class TaskLoss(torch.nn.Module, ABC, make_plugin_registry("loss")): def __init__(self, **kwargs): """Construct loss.""" super().__init__() diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 0f7c030a84..239e0c035b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -28,6 +28,7 @@ EnergySpinLoss, EnergyStdLoss, PropertyLoss, + TaskLoss, TensorLoss, ) from deepmd.pt.model.model import ( @@ -1258,7 +1259,8 @@ def get_loss(loss_params, start_lr, _ntypes, _model): loss_params["task_dim"] = task_dim return PropertyLoss(**loss_params) else: - raise NotImplementedError + loss_params["starter_learning_rate"] = start_lr + return TaskLoss.get_class_by_type(loss_type).get_loss(loss_params) def get_single_model( From 4bcf36db5fdf0a1178788981eddecf9f8167dde3 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Fri, 25 Oct 2024 09:55:43 +0800 Subject: [PATCH 2/2] resolve comment: define a standard input and output for `TaskLoss.get_loss` method --- deepmd/pt/loss/loss.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py index 691646b279..5447c8735b 100644 --- a/deepmd/pt/loss/loss.py +++ b/deepmd/pt/loss/loss.py @@ -41,3 +41,23 @@ def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor: whether the property is found """ return loss if bool(find_property) else torch.nan + + @classmethod + def get_loss(cls, loss_params: dict) -> "TaskLoss": + """Get the loss module by the parameters. + + By default, all the parameters are directly passed to the constructor. + If not, override this method. + + Parameters + ---------- + loss_params : dict + The loss parameters + + Returns + ------- + TaskLoss + The loss module + """ + loss = cls(**loss_params) + return loss