From 7edc31201f5e78e73e31f490ad134c1f9888040f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 24 Jan 2024 00:48:31 -0500 Subject: [PATCH 1/2] add get_ntypes and get_type_map methods to DeepPot Signed-off-by: Jinzhe Zeng --- deepmd_pt/infer/deep_eval.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepmd_pt/infer/deep_eval.py b/deepmd_pt/infer/deep_eval.py index 08e73af..c21bc7c 100644 --- a/deepmd_pt/infer/deep_eval.py +++ b/deepmd_pt/infer/deep_eval.py @@ -9,6 +9,7 @@ from typing import Callable, Optional, Tuple, Union, List from deepmd_pt.utils import env from deepmd_pt.utils.auto_batch_size import AutoBatchSize +from deepmd_utils.infer.deep_pot import DeepPot as DeepPotBase class DeepEval: @@ -53,7 +54,7 @@ def eval( raise NotImplementedError -class DeepPot(DeepEval): +class DeepPot(DeepEval, DeepPotBase): def __init__( self, model_file: "Path", @@ -177,6 +178,14 @@ def _eval_model( else: return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out + def get_ntypes(self) -> int: + """Get the number of atom types of this model.""" + return len(self.type_map) + + def get_type_map(self) -> List[str]: + """Get the type map (element name of the atom types) of this model.""" + return self.type_map + # For tests only def eval_model( From 465f57d5fbf5202f9961fbcdbe6e0e5dc7d3c919 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 24 Jan 2024 00:55:52 -0500 Subject: [PATCH 2/2] add get_dim_fparam and get_dim_aparam Signed-off-by: Jinzhe Zeng --- deepmd_pt/infer/deep_eval.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/deepmd_pt/infer/deep_eval.py b/deepmd_pt/infer/deep_eval.py index c21bc7c..878f8a6 100644 --- a/deepmd_pt/infer/deep_eval.py +++ b/deepmd_pt/infer/deep_eval.py @@ -186,6 +186,14 @@ def get_type_map(self) -> List[str]: """Get the type map (element name of the atom types) of this model.""" return self.type_map + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return 0 + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return 0 + # For tests only def eval_model(