diff --git a/deepmd/backend/suffix.py b/deepmd/backend/suffix.py new file mode 100644 index 0000000000..273fbc0951 --- /dev/null +++ b/deepmd/backend/suffix.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import functools +import operator +from pathlib import ( + Path, +) +from typing import ( + Optional, + Type, + Union, +) + +from deepmd.backend.backend import ( + Backend, +) + + +def format_model_suffix( + filename: str, + feature: Optional[Backend.Feature] = None, + preferred_backend: Optional[Union[str, Type["Backend"]]] = None, + strict_prefer: Optional[bool] = None, +) -> str: + """Check and format the suffixes of a filename. + + When preferred_backend is not given, this method checks the suffix of the filename + is within the suffixes of the any backends (with the given feature) and doesn't do formating. + When preferred_backend is given, strict_prefer must be given. + If strict_prefer is True and the suffix is not within the suffixes of the preferred backend, + or strict_prefer is False and the suffix is not within the suffixes of the any backend with the given feature, + the filename will be formatted with the preferred suffix of the preferred backend. + + Parameters + ---------- + filename : str + The filename to be formatted. + feature : Backend.Feature, optional + The feature of the backend, by default None + preferred_backend : str or type of Backend, optional + The preferred backend, by default None + strict_prefer : bool, optional + Whether to strictly prefer the preferred backend, by default None + + Returns + ------- + str + The formatted filename with the correct suffix. + + Raises + ------ + ValueError + When preferred_backend is not given and the filename is not supported by any backend. + """ + if preferred_backend is not None and strict_prefer is None: + raise ValueError("strict_prefer must be given when preferred_backend is given.") + if isinstance(preferred_backend, str): + preferred_backend = Backend.get_backend(preferred_backend) + if preferred_backend is not None and strict_prefer: + all_backends = [preferred_backend] + elif feature is None: + all_backends = list(Backend.get_backends().values()) + else: + all_backends = list(Backend.get_backends_by_feature(feature).values()) + + all_suffixes = set( + functools.reduce( + operator.iconcat, [backend.suffixes for backend in all_backends], [] + ) + ) + pp = Path(filename) + current_suffix = pp.suffix + if current_suffix not in all_suffixes: + if preferred_backend is not None: + return str(pp) + preferred_backend.suffixes[0] + raise ValueError(f"Unsupported model file format: {filename}") + return filename diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py new file mode 100644 index 0000000000..9a03ac5e45 --- /dev/null +++ b/deepmd/entrypoints/main.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Common entrypoints.""" + +import argparse +from pathlib import ( + Path, +) + +from deepmd.backend.backend import ( + Backend, +) +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.entrypoints.doc import ( + doc_train_input, +) +from deepmd.entrypoints.gui import ( + start_dpgui, +) +from deepmd.entrypoints.neighbor_stat import ( + neighbor_stat, +) +from deepmd.entrypoints.test import ( + test, +) +from deepmd.infer.model_devi import ( + make_model_devi, +) +from deepmd.loggers.loggers import ( + set_log_handles, +) + + +def main(args: argparse.Namespace): + """DeePMD-Kit entry point. + + Parameters + ---------- + args : List[str] or argparse.Namespace, optional + list of command line arguments, used to avoid calling from the subprocess, + as it is quite slow to import tensorflow; if Namespace is given, it will + be used directly + + Raises + ------ + RuntimeError + if no command was input + """ + set_log_handles(args.log_level, Path(args.log_path) if args.log_path else None) + + dict_args = vars(args) + + if args.command == "test": + dict_args["model"] = format_model_suffix( + dict_args["model"], + feature=Backend.Feature.DEEP_EVAL, + preferred_backend=args.backend, + strict_prefer=False, + ) + test(**dict_args) + elif args.command == "doc-train-input": + doc_train_input(**dict_args) + elif args.command == "model-devi": + dict_args["models"] = [ + format_model_suffix( + mm, + feature=Backend.Feature.DEEP_EVAL, + preferred_backend=args.backend, + strict_prefer=False, + ) + for mm in dict_args["models"] + ] + make_model_devi(**dict_args) + elif args.command == "neighbor-stat": + neighbor_stat(**dict_args) + elif args.command == "gui": + start_dpgui(**dict_args) + else: + raise ValueError(f"Unknown command: {args.command}") diff --git a/deepmd/main.py b/deepmd/main.py index 2bde6376f2..98f5ab0c6b 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -760,6 +760,28 @@ def main(): if args.backend not in BACKEND_TABLE: raise ValueError(f"Unknown backend {args.backend}") - deepmd_main = BACKENDS[args.backend]().entry_point_hook + + if args.command in ( + "test", + "doc-train-input", + "model-devi", + "neighbor-stat", + "gui", + ): + # common entrypoints + from deepmd.entrypoints.main import main as deepmd_main + elif args.command in ( + "train", + "freeze", + "transfer", + "compress", + "convert-from", + "train-nvnmd", + ): + deepmd_main = BACKENDS[args.backend]().entry_point_hook + elif args.command is None: + pass + else: + raise RuntimeError(f"unknown command {args.command}") deepmd_main(args) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 8ed9c51634..212a6824e7 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -23,21 +23,6 @@ from deepmd import ( __version__, ) -from deepmd.entrypoints.doc import ( - doc_train_input, -) -from deepmd.entrypoints.gui import ( - start_dpgui, -) -from deepmd.entrypoints.neighbor_stat import ( - neighbor_stat, -) -from deepmd.entrypoints.test import ( - test, -) -from deepmd.infer.model_devi import ( - make_model_devi, -) from deepmd.loggers.loggers import ( set_log_handles, ) @@ -281,22 +266,13 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): FLAGS = parse_args(args=args) else: FLAGS = args - dict_args = vars(FLAGS) set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None) log.debug("Log handles were successfully set") - log.info("DeepMD version: %s", __version__) if FLAGS.command == "train": train(FLAGS) - elif FLAGS.command == "test": - dict_args["output"] = ( - str(Path(FLAGS.model).with_suffix(".pth")) - if Path(FLAGS.model).suffix not in (".pt", ".pth") - else FLAGS.model - ) - test(**dict_args) elif FLAGS.command == "freeze": if Path(FLAGS.checkpoint_folder).is_dir(): checkpoint_path = Path(FLAGS.checkpoint_folder) @@ -306,20 +282,6 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): FLAGS.model = FLAGS.checkpoint_folder FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth")) freeze(FLAGS) - elif FLAGS.command == "doc-train-input": - doc_train_input(**dict_args) - elif FLAGS.command == "model-devi": - dict_args["models"] = [ - str(Path(mm).with_suffix(".pth")) - if Path(mm).suffix not in (".pb", ".pt", ".pth") - else mm - for mm in dict_args["models"] - ] - make_model_devi(**dict_args) - elif FLAGS.command == "gui": - start_dpgui(**dict_args) - elif FLAGS.command == "neighbor-stat": - neighbor_stat(**dict_args) else: raise RuntimeError(f"Invalid command {FLAGS.command}!") diff --git a/deepmd/tf/entrypoints/main.py b/deepmd/tf/entrypoints/main.py index d57b43fc7c..493e5b7aa4 100644 --- a/deepmd/tf/entrypoints/main.py +++ b/deepmd/tf/entrypoints/main.py @@ -11,6 +11,9 @@ Union, ) +from deepmd.backend.suffix import ( + format_model_suffix, +) from deepmd.main import ( get_ll, main_parser, @@ -22,12 +25,7 @@ from deepmd.tf.entrypoints import ( compress, convert, - doc_train_input, freeze, - make_model_devi, - neighbor_stat, - start_dpgui, - test, train_dp, transfer, ) @@ -73,33 +71,18 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): if args.command == "train": train_dp(**dict_args) elif args.command == "freeze": - dict_args["output"] = str(Path(dict_args["output"]).with_suffix(".pb")) + dict_args["output"] = format_model_suffix( + dict_args["output"], preferred_backend=args.backend, strict_prefer=True + ) freeze(**dict_args) - elif args.command == "test": - dict_args["model"] = str(Path(dict_args["model"]).with_suffix(".pb")) - test(**dict_args) elif args.command == "transfer": transfer(**dict_args) elif args.command == "compress": compress(**dict_args) - elif args.command == "doc-train-input": - doc_train_input(**dict_args) - elif args.command == "model-devi": - dict_args["models"] = [ - str(Path(mm).with_suffix(".pb")) - if Path(mm).suffix not in (".pb", ".pt") - else mm - for mm in dict_args["models"] - ] - make_model_devi(**dict_args) elif args.command == "convert-from": convert(**dict_args) - elif args.command == "neighbor-stat": - neighbor_stat(**dict_args) elif args.command == "train-nvnmd": # nvnmd train_nvnmd(**dict_args) - elif args.command == "gui": - start_dpgui(**dict_args) elif args.command is None: pass else: diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index c2def1677f..7ca2b6f4ab 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -651,7 +651,6 @@ def __new__(cls, *args, **kwargs): fitting_type = type(kwargs["fitting_net"]) else: raise RuntimeError("get unknown fitting type when building model") - print(fitting_type) # init model # infer model type by fitting_type if issubclass(fitting_type, EnerFitting):