Skip to content

Commit

Permalink
add more methods to DeepPot (#175)
Browse files Browse the repository at this point in the history
* add get_ntypes and get_type_map methods to DeepPot

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>

* add get_dim_fparam and get_dim_aparam

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Jan 24, 2024
1 parent c3cf976 commit 05c3f9d
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion deepmd_pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, Optional, Tuple, Union, List
from deepmd_pt.utils import env
from deepmd_pt.utils.auto_batch_size import AutoBatchSize
from deepmd_utils.infer.deep_pot import DeepPot as DeepPotBase


class DeepEval:
Expand Down Expand Up @@ -53,7 +54,7 @@ def eval(
raise NotImplementedError


class DeepPot(DeepEval):
class DeepPot(DeepEval, DeepPotBase):
def __init__(
self,
model_file: "Path",
Expand Down Expand Up @@ -177,6 +178,22 @@ def _eval_model(
else:
return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out

def get_ntypes(self) -> int:
"""Get the number of atom types of this model."""
return len(self.type_map)

def get_type_map(self) -> List[str]:
"""Get the type map (element name of the atom types) of this model."""
return self.type_map

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this DP."""
return 0

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this DP."""
return 0


# For tests only
def eval_model(
Expand Down

0 comments on commit 05c3f9d

Please sign in to comment.