From 14d9364b41cbdd2389a64e02903fc71c88075f2e Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Date: Tue, 28 May 2024 18:01:32 -0400
Subject: [PATCH] fix: fix DeepGlobalPolar and DeepWFC initlization

Fix #3561. Fix #3562.

Not sure if some one uses them, but it's good to keep compatibility.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
---
 deepmd/infer/deep_eval.py             |  3 ++
 deepmd/infer/deep_polar.py            | 27 +++++++++++++++++-
 deepmd/infer/deep_tensor.py           | 21 ++++++++++++++
 deepmd/infer/deep_wfc.py              | 28 ++++++++++++++++--
 source/tests/tf/test_get_potential.py | 41 ++++++++++++++++-----------
 5 files changed, 101 insertions(+), 19 deletions(-)

diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py
index 5a00ba616d..879455b942 100644
--- a/deepmd/infer/deep_eval.py
+++ b/deepmd/infer/deep_eval.py
@@ -76,6 +76,9 @@ class DeepEvalBackend(ABC):
         "dos_redu": "dos",
         "mask_mag": "mask_mag",
         "mask": "mask",
+        # old models in v1
+        "global_polar": "global_polar",
+        "wfc": "wfc",
     }
 
     @abstractmethod
diff --git a/deepmd/infer/deep_polar.py b/deepmd/infer/deep_polar.py
index c2089b278d..6650c349a2 100644
--- a/deepmd/infer/deep_polar.py
+++ b/deepmd/infer/deep_polar.py
@@ -7,8 +7,14 @@
 
 import numpy as np
 
+from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
+    ModelOutputDef,
+    OutputVariableDef,
+)
 from deepmd.infer.deep_tensor import (
     DeepTensor,
+    OldDeepTensor,
 )
 
 
@@ -36,7 +42,7 @@ def output_tensor_name(self) -> str:
         return "polar"
 
 
-class DeepGlobalPolar(DeepTensor):
+class DeepGlobalPolar(OldDeepTensor):
     @property
     def output_tensor_name(self) -> str:
         return "global_polar"
@@ -95,3 +101,22 @@ def eval(
             mixed_type=mixed_type,
             **kwargs,
         )
+
+    @property
+    def output_def(self) -> ModelOutputDef:
+        """Get the output definition of this model."""
+        # no atomic or differentiable output is defined
+        return ModelOutputDef(
+            FittingOutputDef(
+                [
+                    OutputVariableDef(
+                        self.output_tensor_name,
+                        shape=[-1],
+                        reduciable=False,
+                        r_differentiable=False,
+                        c_differentiable=False,
+                        atomic=False,
+                    ),
+                ]
+            )
+        )
diff --git a/deepmd/infer/deep_tensor.py b/deepmd/infer/deep_tensor.py
index 14e13e7f84..106bc3156c 100644
--- a/deepmd/infer/deep_tensor.py
+++ b/deepmd/infer/deep_tensor.py
@@ -234,3 +234,24 @@ def output_def(self) -> ModelOutputDef:
                 ]
             )
         )
+
+
+class OldDeepTensor(DeepTensor):
+    """Old tensor models from v1, which has no gradient output."""
+
+    # See https://github.com/deepmodeling/deepmd-kit/blob/1d1b251a2c5f05d1401aa89be792f9ed18b8f096/source/train/Model.py#L264
+    def eval_full(
+        self,
+        coords: np.ndarray,
+        cells: Optional[np.ndarray],
+        atom_types: np.ndarray,
+        atomic: bool = False,
+        fparam: Optional[np.ndarray] = None,
+        aparam: Optional[np.ndarray] = None,
+        mixed_type: bool = False,
+        **kwargs: dict,
+    ) -> Tuple[np.ndarray, ...]:
+        """Unsupported method."""
+        raise RuntimeError(
+            "This model does not support eval_full method. Use eval instead."
+        )
diff --git a/deepmd/infer/deep_wfc.py b/deepmd/infer/deep_wfc.py
index deed938e04..d92af28f5a 100644
--- a/deepmd/infer/deep_wfc.py
+++ b/deepmd/infer/deep_wfc.py
@@ -1,10 +1,15 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
+    ModelOutputDef,
+    OutputVariableDef,
+)
 from deepmd.infer.deep_tensor import (
-    DeepTensor,
+    OldDeepTensor,
 )
 
 
-class DeepWFC(DeepTensor):
+class DeepWFC(OldDeepTensor):
     """Deep WFC model.
 
     Parameters
@@ -26,3 +31,22 @@ class DeepWFC(DeepTensor):
     @property
     def output_tensor_name(self) -> str:
         return "wfc"
+
+    @property
+    def output_def(self) -> ModelOutputDef:
+        """Get the output definition of this model."""
+        # no reduciable or differentiable output is defined
+        return ModelOutputDef(
+            FittingOutputDef(
+                [
+                    OutputVariableDef(
+                        self.output_tensor_name,
+                        shape=[-1],
+                        reduciable=False,
+                        r_differentiable=False,
+                        c_differentiable=False,
+                        atomic=True,
+                    ),
+                ]
+            )
+        )
diff --git a/source/tests/tf/test_get_potential.py b/source/tests/tf/test_get_potential.py
index 47462a20a3..fb39d41d2e 100644
--- a/source/tests/tf/test_get_potential.py
+++ b/source/tests/tf/test_get_potential.py
@@ -1,8 +1,15 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 """Test if `DeepPotential` facto function returns the right type of potential."""
 
+import tempfile
 import unittest
 
+from deepmd.infer.deep_polar import (
+    DeepGlobalPolar,
+)
+from deepmd.infer.deep_wfc import (
+    DeepWFC,
+)
 from deepmd.tf.infer import (
     DeepDipole,
     DeepPolar,
@@ -35,16 +42,19 @@ def setUp(self):
             str(self.work_dir / "deeppolar.pbtxt"), str(self.work_dir / "deep_polar.pb")
         )
 
-        # TODO add model files for globalpolar and WFC
-        # convert_pbtxt_to_pb(
-        #     str(self.work_dir / "deepglobalpolar.pbtxt"),
-        #     str(self.work_dir / "deep_globalpolar.pb")
-        # )
+        with open(self.work_dir / "deeppolar.pbtxt") as f:
+            deeppolar_pbtxt = f.read()
 
-        # convert_pbtxt_to_pb(
-        #     str(self.work_dir / "deepwfc.pbtxt"),
-        #     str(self.work_dir / "deep_wfc.pb")
-        # )
+        # not an actual globalpolar and wfc model, but still good enough for testing factory
+        with tempfile.NamedTemporaryFile(mode="w") as f:
+            f.write(deeppolar_pbtxt.replace("polar", "global_polar"))
+            f.flush()
+            convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_globalpolar.pb"))
+
+        with tempfile.NamedTemporaryFile(mode="w") as f:
+            f.write(deeppolar_pbtxt.replace("polar", "wfc"))
+            f.flush()
+            convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_wfc.pb"))
 
     def tearDown(self):
         for f in self.work_dir.glob("*.pb"):
@@ -62,11 +72,10 @@ def test_factory(self):
         dp = DeepPotential(self.work_dir / "deep_pot.pb")
         self.assertIsInstance(dp, DeepPot, msg.format(DeepPot, type(dp)))
 
-        # TODO add model files for globalpolar and WFC
-        # dp = DeepPotential(self.work_dir / "deep_globalpolar.pb")
-        # self.assertIsInstance(
-        #     dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp))
-        # )
+        dp = DeepPotential(self.work_dir / "deep_globalpolar.pb")
+        self.assertIsInstance(
+            dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp))
+        )
 
-        # dp = DeepPotential(self.work_dir / "deep_wfc.pb")
-        # self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp)))
+        dp = DeepPotential(self.work_dir / "deep_wfc.pb")
+        self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp)))