Skip to content

Commit

Permalink
Merge branch 'deepmodeling:devel' into devel
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml authored Feb 22, 2024
2 parents 9744929 + cf21b7a commit e11740b
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 63 deletions.
76 changes: 76 additions & 0 deletions deepmd/backend/suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
import operator
from pathlib import (
Path,
)
from typing import (
Optional,
Type,
Union,
)

from deepmd.backend.backend import (
Backend,
)


def format_model_suffix(
filename: str,
feature: Optional[Backend.Feature] = None,
preferred_backend: Optional[Union[str, Type["Backend"]]] = None,
strict_prefer: Optional[bool] = None,
) -> str:
"""Check and format the suffixes of a filename.
When preferred_backend is not given, this method checks the suffix of the filename
is within the suffixes of the any backends (with the given feature) and doesn't do formating.
When preferred_backend is given, strict_prefer must be given.
If strict_prefer is True and the suffix is not within the suffixes of the preferred backend,
or strict_prefer is False and the suffix is not within the suffixes of the any backend with the given feature,
the filename will be formatted with the preferred suffix of the preferred backend.
Parameters
----------
filename : str
The filename to be formatted.
feature : Backend.Feature, optional
The feature of the backend, by default None
preferred_backend : str or type of Backend, optional
The preferred backend, by default None
strict_prefer : bool, optional
Whether to strictly prefer the preferred backend, by default None
Returns
-------
str
The formatted filename with the correct suffix.
Raises
------
ValueError
When preferred_backend is not given and the filename is not supported by any backend.
"""
if preferred_backend is not None and strict_prefer is None:
raise ValueError("strict_prefer must be given when preferred_backend is given.")
if isinstance(preferred_backend, str):
preferred_backend = Backend.get_backend(preferred_backend)
if preferred_backend is not None and strict_prefer:
all_backends = [preferred_backend]
elif feature is None:
all_backends = list(Backend.get_backends().values())
else:
all_backends = list(Backend.get_backends_by_feature(feature).values())

all_suffixes = set(
functools.reduce(
operator.iconcat, [backend.suffixes for backend in all_backends], []
)
)
pp = Path(filename)
current_suffix = pp.suffix
if current_suffix not in all_suffixes:
if preferred_backend is not None:
return str(pp) + preferred_backend.suffixes[0]
raise ValueError(f"Unsupported model file format: {filename}")
return filename
80 changes: 80 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Common entrypoints."""

import argparse
from pathlib import (
Path,
)

from deepmd.backend.backend import (
Backend,
)
from deepmd.backend.suffix import (
format_model_suffix,
)
from deepmd.entrypoints.doc import (
doc_train_input,
)
from deepmd.entrypoints.gui import (
start_dpgui,
)
from deepmd.entrypoints.neighbor_stat import (
neighbor_stat,
)
from deepmd.entrypoints.test import (
test,
)
from deepmd.infer.model_devi import (
make_model_devi,
)
from deepmd.loggers.loggers import (
set_log_handles,
)


def main(args: argparse.Namespace):
"""DeePMD-Kit entry point.
Parameters
----------
args : List[str] or argparse.Namespace, optional
list of command line arguments, used to avoid calling from the subprocess,
as it is quite slow to import tensorflow; if Namespace is given, it will
be used directly
Raises
------
RuntimeError
if no command was input
"""
set_log_handles(args.log_level, Path(args.log_path) if args.log_path else None)

dict_args = vars(args)

if args.command == "test":
dict_args["model"] = format_model_suffix(
dict_args["model"],
feature=Backend.Feature.DEEP_EVAL,
preferred_backend=args.backend,
strict_prefer=False,
)
test(**dict_args)
elif args.command == "doc-train-input":
doc_train_input(**dict_args)
elif args.command == "model-devi":
dict_args["models"] = [
format_model_suffix(
mm,
feature=Backend.Feature.DEEP_EVAL,
preferred_backend=args.backend,
strict_prefer=False,
)
for mm in dict_args["models"]
]
make_model_devi(**dict_args)
elif args.command == "neighbor-stat":
neighbor_stat(**dict_args)
elif args.command == "gui":
start_dpgui(**dict_args)
else:
raise ValueError(f"Unknown command: {args.command}")
24 changes: 23 additions & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,28 @@ def main():

if args.backend not in BACKEND_TABLE:
raise ValueError(f"Unknown backend {args.backend}")
deepmd_main = BACKENDS[args.backend]().entry_point_hook

if args.command in (
"test",
"doc-train-input",
"model-devi",
"neighbor-stat",
"gui",
):
# common entrypoints
from deepmd.entrypoints.main import main as deepmd_main
elif args.command in (
"train",
"freeze",
"transfer",
"compress",
"convert-from",
"train-nvnmd",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
elif args.command is None:
pass
else:
raise RuntimeError(f"unknown command {args.command}")

deepmd_main(args)
38 changes: 0 additions & 38 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,6 @@
from deepmd import (
__version__,
)
from deepmd.entrypoints.doc import (
doc_train_input,
)
from deepmd.entrypoints.gui import (
start_dpgui,
)
from deepmd.entrypoints.neighbor_stat import (
neighbor_stat,
)
from deepmd.entrypoints.test import (
test,
)
from deepmd.infer.model_devi import (
make_model_devi,
)
from deepmd.loggers.loggers import (
set_log_handles,
)
Expand Down Expand Up @@ -281,22 +266,13 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
FLAGS = parse_args(args=args)
else:
FLAGS = args
dict_args = vars(FLAGS)

set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None)
log.debug("Log handles were successfully set")

log.info("DeepMD version: %s", __version__)

if FLAGS.command == "train":
train(FLAGS)
elif FLAGS.command == "test":
dict_args["output"] = (
str(Path(FLAGS.model).with_suffix(".pth"))
if Path(FLAGS.model).suffix not in (".pt", ".pth")
else FLAGS.model
)
test(**dict_args)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
Expand All @@ -306,20 +282,6 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
freeze(FLAGS)
elif FLAGS.command == "doc-train-input":
doc_train_input(**dict_args)
elif FLAGS.command == "model-devi":
dict_args["models"] = [
str(Path(mm).with_suffix(".pth"))
if Path(mm).suffix not in (".pb", ".pt", ".pth")
else mm
for mm in dict_args["models"]
]
make_model_devi(**dict_args)
elif FLAGS.command == "gui":
start_dpgui(**dict_args)
elif FLAGS.command == "neighbor-stat":
neighbor_stat(**dict_args)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
29 changes: 6 additions & 23 deletions deepmd/tf/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
Union,
)

from deepmd.backend.suffix import (
format_model_suffix,
)
from deepmd.main import (
get_ll,
main_parser,
Expand All @@ -22,12 +25,7 @@
from deepmd.tf.entrypoints import (
compress,
convert,
doc_train_input,
freeze,
make_model_devi,
neighbor_stat,
start_dpgui,
test,
train_dp,
transfer,
)
Expand Down Expand Up @@ -73,33 +71,18 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
if args.command == "train":
train_dp(**dict_args)
elif args.command == "freeze":
dict_args["output"] = str(Path(dict_args["output"]).with_suffix(".pb"))
dict_args["output"] = format_model_suffix(
dict_args["output"], preferred_backend=args.backend, strict_prefer=True
)
freeze(**dict_args)
elif args.command == "test":
dict_args["model"] = str(Path(dict_args["model"]).with_suffix(".pb"))
test(**dict_args)
elif args.command == "transfer":
transfer(**dict_args)
elif args.command == "compress":
compress(**dict_args)
elif args.command == "doc-train-input":
doc_train_input(**dict_args)
elif args.command == "model-devi":
dict_args["models"] = [
str(Path(mm).with_suffix(".pb"))
if Path(mm).suffix not in (".pb", ".pt")
else mm
for mm in dict_args["models"]
]
make_model_devi(**dict_args)
elif args.command == "convert-from":
convert(**dict_args)
elif args.command == "neighbor-stat":
neighbor_stat(**dict_args)
elif args.command == "train-nvnmd": # nvnmd
train_nvnmd(**dict_args)
elif args.command == "gui":
start_dpgui(**dict_args)
elif args.command is None:
pass
else:
Expand Down
1 change: 0 additions & 1 deletion deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,6 @@ def __new__(cls, *args, **kwargs):
fitting_type = type(kwargs["fitting_net"])
else:
raise RuntimeError("get unknown fitting type when building model")
print(fitting_type)
# init model
# infer model type by fitting_type
if issubclass(fitting_type, EnerFitting):
Expand Down

0 comments on commit e11740b

Please sign in to comment.