From 05c3f9d0079e7263b20a7c20c16122f93554a6fa Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 24 Jan 2024 09:01:02 -0500 Subject: [PATCH] add more methods to DeepPot (#175) * add get_ntypes and get_type_map methods to DeepPot Signed-off-by: Jinzhe Zeng * add get_dim_fparam and get_dim_aparam Signed-off-by: Jinzhe Zeng --------- Signed-off-by: Jinzhe Zeng --- deepmd_pt/infer/deep_eval.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/deepmd_pt/infer/deep_eval.py b/deepmd_pt/infer/deep_eval.py index 08e73af..878f8a6 100644 --- a/deepmd_pt/infer/deep_eval.py +++ b/deepmd_pt/infer/deep_eval.py @@ -9,6 +9,7 @@ from typing import Callable, Optional, Tuple, Union, List from deepmd_pt.utils import env from deepmd_pt.utils.auto_batch_size import AutoBatchSize +from deepmd_utils.infer.deep_pot import DeepPot as DeepPotBase class DeepEval: @@ -53,7 +54,7 @@ def eval( raise NotImplementedError -class DeepPot(DeepEval): +class DeepPot(DeepEval, DeepPotBase): def __init__( self, model_file: "Path", @@ -177,6 +178,22 @@ def _eval_model( else: return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out + def get_ntypes(self) -> int: + """Get the number of atom types of this model.""" + return len(self.type_map) + + def get_type_map(self) -> List[str]: + """Get the type map (element name of the atom types) of this model.""" + return self.type_map + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return 0 + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return 0 + # For tests only def eval_model(