pt: support loading frozen models in DeepEval (#3253)
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
njzjz authored Feb 13, 2024
1 parent 1bdc60d commit 392b9e0
Showing 4 changed files with 45 additions and 18 deletions.
4 changes: 2 additions & 2 deletions deepmd/main.py
@@ -329,7 +329,7 @@ def main_parser() -> argparse.ArgumentParser:
         "--model",
         default="frozen_model",
         type=str,
-        help="Frozen model file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pt",
+        help="Frozen model file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth.",
     )
     parser_tst_subgroup = parser_tst.add_mutually_exclusive_group()
     parser_tst_subgroup.add_argument(
@@ -512,7 +512,7 @@ def main_parser() -> argparse.ArgumentParser:
         default=["graph.000", "graph.001", "graph.002", "graph.003"],
         nargs="+",
         type=str,
-        help="Frozen models file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pt.",
+        help="Frozen models file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth.",
     )
     parser_model_devi.add_argument(
         "-s",
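Both help strings now advertise .pth as the PyTorch suffix. A self-contained sketch of the flag they document (the default and help text are copied from the hunks above; the parser wiring and the "-m" short option are illustrative):

import argparse

# Illustrative parser; only the flag's default and help text come from the diff.
parser = argparse.ArgumentParser(prog="dp test")
parser.add_argument(
    "-m",
    "--model",
    default="frozen_model",
    type=str,
    help="Frozen model file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth.",
)
print(parser.parse_args([]).model)  # -> "frozen_model"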
10 changes: 7 additions & 3 deletions deepmd/pt/entrypoints/main.py
@@ -301,7 +301,11 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
     if FLAGS.command == "train":
         train(FLAGS)
     elif FLAGS.command == "test":
-        dict_args["output"] = str(Path(FLAGS.model).with_suffix(".pt"))
+        dict_args["output"] = (
+            str(Path(FLAGS.model).with_suffix(".pth"))
+            if Path(FLAGS.model).suffix not in (".pt", ".pth")
+            else FLAGS.model
+        )
         test(**dict_args)
     elif FLAGS.command == "freeze":
         if Path(FLAGS.checkpoint_folder).is_dir():
@@ -316,8 +320,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
         doc_train_input(**dict_args)
     elif FLAGS.command == "model-devi":
         dict_args["models"] = [
-            str(Path(mm).with_suffix(".pt"))
-            if Path(mm).suffix not in (".pb", ".pt")
+            str(Path(mm).with_suffix(".pth"))
+            if Path(mm).suffix not in (".pb", ".pt", ".pth")
             else mm
             for mm in dict_args["models"]
         ]
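Both the test and model-devi commands now funnel model names through the same suffix check. A minimal standalone sketch of that normalization, assuming only pathlib semantics (the sketch follows the model-devi variant, which also accepts .pb); note that with_suffix() replaces an existing suffix rather than appending one:

from pathlib import Path

def normalize_model_name(name: str) -> str:
    # Recognized suffixes pass through untouched; anything else is treated
    # as a prefix and resolved to the frozen TorchScript suffix ".pth".
    p = Path(name)
    return name if p.suffix in (".pb", ".pt", ".pth") else str(p.with_suffix(".pth"))

assert normalize_model_name("frozen_model") == "frozen_model.pth"
assert normalize_model_name("model.ckpt.pt") == "model.ckpt.pt"  # checkpoint kept as-is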
32 changes: 19 additions & 13 deletions deepmd/pt/infer/deep_eval.py
@@ -91,19 +91,25 @@ def __init__(
     ):
         self.output_def = output_def
         self.model_path = model_file
-        state_dict = torch.load(model_file, map_location=env.DEVICE)
-        if "model" in state_dict:
-            state_dict = state_dict["model"]
-        self.input_param = state_dict["_extra_state"]["model_params"]
-        self.input_param["resuming"] = True
-        self.multi_task = "model_dict" in self.input_param
-        assert not self.multi_task, "multitask mode currently not supported!"
-        self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"]
-        self.type_map = self.input_param["type_map"]
-        self.dp = ModelWrapper(get_model(self.input_param).to(DEVICE))
-        self.dp.load_state_dict(state_dict)
-        self.rcut = self.dp.model["Default"].descriptor.get_rcut()
-        self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel())
+        if str(self.model_path).endswith(".pt"):
+            state_dict = torch.load(model_file, map_location=env.DEVICE)
+            if "model" in state_dict:
+                state_dict = state_dict["model"]
+            self.input_param = state_dict["_extra_state"]["model_params"]
+            self.input_param["resuming"] = True
+            self.multi_task = "model_dict" in self.input_param
+            assert not self.multi_task, "multitask mode currently not supported!"
+            model = get_model(self.input_param).to(DEVICE)
+            model = torch.jit.script(model)
+            self.dp = ModelWrapper(model)
+            self.dp.load_state_dict(state_dict)
+        elif str(self.model_path).endswith(".pth"):
+            model = torch.jit.load(model_file, map_location=env.DEVICE)
+            self.dp = ModelWrapper(model)
+        else:
+            raise ValueError("Unknown model file format!")
+        self.rcut = self.dp.model["Default"].get_rcut()
+        self.type_map = self.dp.model["Default"].get_type_map()
         if isinstance(auto_batch_size, bool):
             if auto_batch_size:
                 self.auto_batch_size = AutoBatchSize()
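The constructor now dispatches on the file suffix: a .pt checkpoint is rebuilt from the hyperparameters stored under "_extra_state" and then populated via load_state_dict, while a .pth file is a self-describing TorchScript module. A condensed sketch of the two branches, with the checkpoint layout assumed from the hunk above:

import torch

def load_pt_model(model_file: str, device: str = "cpu"):
    if model_file.endswith(".pt"):
        # Checkpoint: a state_dict whose "_extra_state" carries the
        # hyperparameters needed to rebuild the network before loading weights.
        state_dict = torch.load(model_file, map_location=device)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model_params = state_dict["_extra_state"]["model_params"]
        # ... rebuild the network from model_params, then load_state_dict(state_dict)
        return model_params, state_dict
    elif model_file.endswith(".pth"):
        # Frozen model: architecture and weights travel together in TorchScript.
        return torch.jit.load(model_file, map_location=device)
    raise ValueError("Unknown model file format!")

Scripting the rebuilt checkpoint with torch.jit.script keeps both branches handing ModelWrapper the same kind of module, which is presumably why get_rcut() and get_type_map() move from the descriptor onto the model itself.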
17 changes: 17 additions & 0 deletions source/tests/pt/model/test_deeppot.py
@@ -1,6 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import json
 import unittest
+from argparse import (
+    Namespace,
+)
 from copy import (
     deepcopy,
 )
@@ -12,6 +15,7 @@
 
 from deepmd.infer.deep_pot import DeepPot as DeepPotUni
 from deepmd.pt.entrypoints.main import (
+    freeze,
     get_trainer,
 )
 from deepmd.pt.infer.deep_eval import (
@@ -95,3 +99,16 @@ def test_uni(self):
         dp = DeepPotUni("model.pt")
         self.assertIsInstance(dp, DeepPot)
         # its methods has been tested in test_dp_test
+
+
+class TestDeepPotFrozen(TestDeepPot):
+    def setUp(self):
+        super().setUp()
+        frozen_model = "frozen_model.pth"
+        ns = Namespace(
+            model=self.model,
+            output=frozen_model,
+            head=None,
+        )
+        freeze(ns)
+        self.model = frozen_model
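The subclass freezes the trained checkpoint in setUp() and then inherits every TestDeepPot case, so the whole suite re-runs against the .pth file. A hypothetical direct use of the frozen result (class name as imported in this test; the single-argument constructor is assumed):

from deepmd.pt.infer.deep_eval import DeepPot

# "frozen_model.pth" ends with ".pth", so DeepEval takes the torch.jit.load branch.
dp = DeepPot("frozen_model.pth")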
