From 7f3ab3acb779ba486d8940a587d96bb57fab5c38 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 20:37:07 -0500 Subject: [PATCH 1/4] pluggable backend Signed-off-by: Jinzhe Zeng --- deepmd/backend/__init__.py | 26 ++++ deepmd/backend/backend.py | 200 ++++++++++++++++++++++++++++ deepmd/backend/dpmodel.py | 86 ++++++++++++ deepmd/backend/pytorch.py | 95 +++++++++++++ deepmd/backend/tensorflow.py | 104 +++++++++++++++ deepmd/entrypoints/neighbor_stat.py | 21 ++- deepmd/infer/backend.py | 34 ----- deepmd/infer/deep_eval.py | 21 +-- deepmd/main.py | 52 ++++---- source/tests/consistent/common.py | 11 +- 10 files changed, 557 insertions(+), 93 deletions(-) create mode 100644 deepmd/backend/__init__.py create mode 100644 deepmd/backend/backend.py create mode 100644 deepmd/backend/dpmodel.py create mode 100644 deepmd/backend/pytorch.py create mode 100644 deepmd/backend/tensorflow.py delete mode 100644 deepmd/infer/backend.py diff --git a/deepmd/backend/__init__.py b/deepmd/backend/__init__.py new file mode 100644 index 0000000000..c0e40063a4 --- /dev/null +++ b/deepmd/backend/__init__.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Backends. + +Avoid directly importing third-party libraries in this module for performance. +""" +# copy from dpdata +import importlib +from importlib import ( + metadata, +) +from pathlib import ( + Path, +) + +PACKAGE_BASE = "deepmd.backend" +NOT_LOADABLE = ("__init__.py",) + +for module_file in Path(__file__).parent.glob("*.py"): + if module_file.name not in NOT_LOADABLE: + module_name = f".{module_file.stem}" + importlib.import_module(module_name, PACKAGE_BASE) + +# https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html +eps = metadata.entry_points(group="deepmd.backend") +for ep in eps: + plugin = ep.load() diff --git a/deepmd/backend/backend.py b/deepmd/backend/backend.py new file mode 100644 index 0000000000..7fcf7c9f40 --- /dev/null +++ b/deepmd/backend/backend.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + abstractmethod, +) +from enum import ( + Flag, + auto, +) +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, + Dict, + List, + Type, +) + +from deepmd.utils.plugin import ( + Plugin, + PluginVariant, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + +class Backend(PluginVariant): + r"""General backend class. + + Examples + -------- + >>> @Backend.register("tf") + >>> @Backend.register("tensorflow") + >>> class TensorFlowBackend(Backend): + ... pass + """ + + __plugins = Plugin() + + @staticmethod + def register(key: str) -> Callable[[object], object]: + """Register a backend plugin. + + Parameters + ---------- + key : str + the key of a backend + + Returns + ------- + Callable[[object], object] + the decorator to register backend + """ + return Backend.__plugins.register(key.lower()) + + @staticmethod + def get_backend(key: str) -> Type["Backend"]: + """Get the backend by key. + + Parameters + ---------- + key : str + the key of a backend + + Returns + ------- + Backend + the backend + """ + try: + backend = Backend.__plugins.get_plugin(key.lower()) + except KeyError: + raise KeyError(f"Backend {key} is not registered.") + assert isinstance(backend, type) + return backend + + @staticmethod + def get_backends() -> Dict[str, Type["Backend"]]: + """Get all the registered backend names. + + Returns + ------- + list + all the registered backends + """ + return Backend.__plugins.plugins + + @staticmethod + def get_backends_by_feature( + feature: "Backend.Feature", + ) -> Dict[str, Type["Backend"]]: + """Get all the registered backend names with a specific feature. + + Parameters + ---------- + feature : Backend.Feature + the feature flag + + Returns + ------- + list + all the registered backends with the feature + """ + return { + key: backend + for key, backend in Backend.__plugins.plugins.items() + if backend.features & feature + } + + @staticmethod + def detect_backend_by_model(filename: str) -> Type["Backend"]: + """Detect the backend of the given model file. + + Parameters + ---------- + filename : str + The model file name + """ + filename = str(filename).lower() + for backend in Backend.get_backends().values(): + if filename.endswith(backend.suffixes[0]): + return backend + raise ValueError(f"Cannot detect the backend of the model file {filename}.") + + class Feature(Flag): + """Feature flag to indicate whether the backend supports certain features.""" + + ENTRY_POINT = auto() + """Support entry point hook.""" + DEEP_EVAL = auto() + """Support Deep Eval backend.""" + NEIGHBOR_STAT = auto() + """Support neighbor statistics.""" + + name: ClassVar[str] = "Unknown" + """The formal name of the backend. + + To be consistent, this name should be also registered in the plugin system.""" + + features: ClassVar[Feature] = Feature(0) + """The features of the backend.""" + suffixes: ClassVar[List[str]] = [] + """The supported suffixes of the saved model. + + The first element is considered as the default suffix.""" + + @abstractmethod + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + + @property + @abstractmethod + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + pass + + @property + @abstractmethod + def deep_eval(self) -> Type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + pass + + @property + @abstractmethod + def neighbor_stat(self) -> Type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + pass diff --git a/deepmd/backend/dpmodel.py b/deepmd/backend/dpmodel.py new file mode 100644 index 0000000000..2e24f36447 --- /dev/null +++ b/deepmd/backend/dpmodel.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, + List, + Type, +) + +from deepmd.backend.backend import ( + Backend, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + +@Backend.register("dp") +@Backend.register("dpmodel") +@Backend.register("np") +@Backend.register("numpy") +class DPModelBackend(Backend): + """DPModel backend that uses NumPy as the reference implementation.""" + + name = "DPModel" + """The formal name of the backend.""" + features: ClassVar[Backend.Feature] = Backend.Feature.NEIGHBOR_STAT + """The features of the backend.""" + suffixes: ClassVar[List[str]] = ["dp"] + """The suffixes of the backend.""" + + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + return True + + @property + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + raise NotImplementedError(f"Unsupported backend: {self.name}") + + @property + def deep_eval(self) -> Type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + raise NotImplementedError(f"Unsupported backend: {self.name}") + + @property + def neighbor_stat(self) -> Type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + from deepmd.dpmodel.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat diff --git a/deepmd/backend/pytorch.py b/deepmd/backend/pytorch.py new file mode 100644 index 0000000000..146ce7f9f2 --- /dev/null +++ b/deepmd/backend/pytorch.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from importlib.util import ( + find_spec, +) +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, + List, + Type, +) + +from deepmd.backend.backend import ( + Backend, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + +@Backend.register("pt") +@Backend.register("pytorch") +class TensorFlowBackend(Backend): + """TensorFlow backend.""" + + name = "PyTorch" + """The formal name of the backend.""" + features: ClassVar[Backend.Feature] = ( + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + ) + """The features of the backend.""" + suffixes: ClassVar[List[str]] = ["pth", "pt"] + """The suffixes of the backend.""" + + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + return find_spec("torch") is not None + + @property + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + from deepmd.pt.entrypoints.main import main as deepmd_main + + return deepmd_main + + @property + def deep_eval(self) -> Type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT + + return DeepEvalPT + + @property + def neighbor_stat(self) -> Type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + from deepmd.pt.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat diff --git a/deepmd/backend/tensorflow.py b/deepmd/backend/tensorflow.py new file mode 100644 index 0000000000..b37b887bec --- /dev/null +++ b/deepmd/backend/tensorflow.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from importlib.util import ( + find_spec, +) +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, + List, + Type, +) + +from deepmd.backend.backend import ( + Backend, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + +@Backend.register("tf") +@Backend.register("tensorflow") +class TensorFlowBackend(Backend): + """TensorFlow backend.""" + + name = "TensorFlow" + """The formal name of the backend.""" + features: ClassVar[Backend.Feature] = ( + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + ) + """The features of the backend.""" + suffixes: ClassVar[List[str]] = ["pb"] + """The suffixes of the backend.""" + + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + # deepmd.env imports expensive numpy + # avoid import outside the method + from deepmd.env import ( + GLOBAL_CONFIG, + ) + + return ( + find_spec("tensorflow") is not None + and GLOBAL_CONFIG["enable_tensorflow"] != "0" + ) + + @property + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + from deepmd.tf.entrypoints.main import main as deepmd_main + + return deepmd_main + + @property + def deep_eval(self) -> Type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + from deepmd.tf.infer.deep_eval import DeepEval as DeepEvalTF + + return DeepEvalTF + + @property + def neighbor_stat(self) -> Type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + from deepmd.tf.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat diff --git a/deepmd/entrypoints/neighbor_stat.py b/deepmd/entrypoints/neighbor_stat.py index f5ce0f839d..8a496fb6f0 100644 --- a/deepmd/entrypoints/neighbor_stat.py +++ b/deepmd/entrypoints/neighbor_stat.py @@ -4,6 +4,9 @@ List, ) +from deepmd.backend.backend import ( + Backend, +) from deepmd.common import ( expand_sys_str, ) @@ -69,20 +72,12 @@ def neighbor_stat( min_nbor_dist: 0.6599510670195264 max_nbor_size: [23, 26, 19, 16, 2, 2, 1, 1, 72, 37, 5, 0, 31, 29, 1, 21, 20, 5] """ - if backend == "tensorflow": - from deepmd.tf.utils.neighbor_stat import ( - NeighborStat, - ) - elif backend == "pytorch": - from deepmd.pt.utils.neighbor_stat import ( - NeighborStat, - ) - elif backend == "numpy": - from deepmd.dpmodel.utils.neighbor_stat import ( - NeighborStat, - ) - else: + backends = Backend.get_backends_by_feature(Backend.Feature.NEIGHBOR_STAT) + try: + backend_obj = backends[backend]() + except KeyError: raise ValueError(f"Invalid backend {backend}") + NeighborStat = backend_obj.neighbor_stat all_sys = expand_sys_str(system) if not len(all_sys): raise RuntimeError("Did not find valid system") diff --git a/deepmd/infer/backend.py b/deepmd/infer/backend.py deleted file mode 100644 index 26eef22eb4..0000000000 --- a/deepmd/infer/backend.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from enum import ( - Enum, -) - - -class DPBackend(Enum): - """DeePMD-kit backend.""" - - TensorFlow = 1 - PyTorch = 2 - Paddle = 3 - Unknown = 4 - - -def detect_backend(filename: str) -> DPBackend: - """Detect the backend of the given model file. - - Parameters - ---------- - filename : str - The model file name - """ - filename = str(filename).lower() - if filename.endswith(".pb"): - return DPBackend.TensorFlow - elif filename.endswith(".pth") or filename.endswith(".pt"): - return DPBackend.PyTorch - elif filename.endswith(".pdmodel"): - return DPBackend.Paddle - return DPBackend.Unknown - - -__all__ = ["DPBackend", "detect_backend"] diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 3b1eceb16d..35d170cdab 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -16,6 +16,9 @@ import numpy as np +from deepmd.backend.backend import ( + Backend, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, ModelOutputDef, @@ -24,11 +27,6 @@ AutoBatchSize, ) -from .backend import ( - DPBackend, - detect_backend, -) - if TYPE_CHECKING: import ase.neighborlist @@ -89,17 +87,8 @@ def __init__( def __new__(cls, model_file: str, *args, **kwargs): if cls is DeepEvalBackend: - backend = detect_backend(model_file) - if backend == DPBackend.TensorFlow: - from deepmd.tf.infer.deep_eval import DeepEval as DeepEvalTF - - return super().__new__(DeepEvalTF) - elif backend == DPBackend.PyTorch: - from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT - - return super().__new__(DeepEvalPT) - else: - raise NotImplementedError("Unsupported backend: " + str(backend)) + backend = Backend.detect_backend_by_model(model_file) + return super().__new__(backend().deep_eval) return super().__new__(cls) @abstractmethod diff --git a/deepmd/main.py b/deepmd/main.py index d6714e1e26..d31cab30c2 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -8,9 +8,18 @@ import logging import os import textwrap +from collections import ( + defaultdict, +) from typing import ( + Dict, List, Optional, + Type, +) + +from deepmd.backend.backend import ( + Backend, ) try: @@ -46,12 +55,10 @@ class RawTextArgumentDefaultsHelpFormatter( """This formatter is used to print multile-line help message with default value.""" -BACKEND_TABLE = { - "tensorflow": "tensorflow", - "tf": "tensorflow", - "pytorch": "pytorch", - "pt": "pytorch", -} +BACKENDS: Dict[str, Type[Backend]] = Backend.get_backends_by_feature( + Backend.Feature.ENTRY_POINT +) +BACKEND_TABLE: Dict[str, str] = {kk: vv.name.lower() for kk, vv in BACKENDS.items()} class BackendOption(argparse.Action): @@ -102,20 +109,18 @@ def main_parser() -> argparse.ArgumentParser: "DP_BACKEND." ), ) - parser_backend.add_argument( - "--tf", - action="store_const", - dest="backend", - const="tensorflow", - help="Alias for --backend tensorflow", - ) - parser_backend.add_argument( - "--pt", - action="store_const", - dest="backend", - const="pytorch", - help="Alias for --backend pytorch", - ) + + BACKEND_ALIAS: Dict[str, List[str]] = defaultdict(list) + for alias, backend in BACKEND_TABLE.items(): + BACKEND_ALIAS[backend].append(alias) + for backend, alias in BACKEND_ALIAS.items(): + parser_backend.add_argument( + *[f"--{aa}" for aa in alias], + action="store_const", + dest="backend", + const=backend, + help=f"Alias for --backend {backend}", + ) subparsers = parser.add_subparsers(title="Valid subcommands", dest="command") @@ -752,11 +757,8 @@ def main(): """ args = parse_args() - if args.backend == "tensorflow": - from deepmd.tf.entrypoints.main import main as deepmd_main - elif args.backend == "pytorch": - from deepmd.pt.entrypoints.main import main as deepmd_main - else: + if args.backend not in BACKEND_TABLE: raise ValueError(f"Unknown backend {args.backend}") + deepmd_main = BACKENDS[args.backend]().entry_point_hook deepmd_main(args) diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index e5633726ef..5056b0c2f4 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -9,9 +9,6 @@ from enum import ( Enum, ) -from importlib.util import ( - find_spec, -) from typing import ( Any, Callable, @@ -29,8 +26,12 @@ Argument, ) -INSTALLED_TF = find_spec("tensorflow") is not None -INSTALLED_PT = find_spec("torch") is not None +from deepmd.backend.tensorflow import ( + Backend, +) + +INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() +INSTALLED_PT = Backend.get_backend("pytorch")().is_available() if os.environ.get("CI") and not (INSTALLED_TF and INSTALLED_PT): raise ImportError("TensorFlow or PyTorch should be tested in the CI") From e0dc2f7bc9da3d592cba056e71568f2dd2503fa8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 20:50:56 -0500 Subject: [PATCH 2/4] fix py38 compatibility Signed-off-by: Jinzhe Zeng --- deepmd/backend/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/backend/__init__.py b/deepmd/backend/__init__.py index c0e40063a4..c653d13fcc 100644 --- a/deepmd/backend/__init__.py +++ b/deepmd/backend/__init__.py @@ -21,6 +21,9 @@ importlib.import_module(module_name, PACKAGE_BASE) # https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html -eps = metadata.entry_points(group="deepmd.backend") +try: + eps = metadata.entry_points(group="deepmd.backend") +except TypeError: + eps = metadata.entry_points().get("deepmd.backend", []) for ep in eps: plugin = ep.load() From 8485d9cffb2c4886858a311e51f28dcb777224c9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 20:52:20 -0500 Subject: [PATCH 3/4] fix codeql warning Signed-off-by: Jinzhe Zeng --- deepmd/backend/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/backend/__init__.py b/deepmd/backend/__init__.py index c653d13fcc..8969edd480 100644 --- a/deepmd/backend/__init__.py +++ b/deepmd/backend/__init__.py @@ -4,8 +4,8 @@ Avoid directly importing third-party libraries in this module for performance. """ # copy from dpdata -import importlib from importlib import ( + import_module, metadata, ) from pathlib import ( @@ -18,7 +18,7 @@ for module_file in Path(__file__).parent.glob("*.py"): if module_file.name not in NOT_LOADABLE: module_name = f".{module_file.stem}" - importlib.import_module(module_name, PACKAGE_BASE) + import_module(module_name, PACKAGE_BASE) # https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html try: From 908d74800b9ed95d4db7d65d7584b48fea16edaf Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 21:20:22 -0500 Subject: [PATCH 4/4] fix suffix behaviors Signed-off-by: Jinzhe Zeng --- deepmd/backend/backend.py | 5 +++-- deepmd/backend/dpmodel.py | 2 +- deepmd/backend/pytorch.py | 2 +- deepmd/backend/tensorflow.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/deepmd/backend/backend.py b/deepmd/backend/backend.py index 7fcf7c9f40..179b2e556a 100644 --- a/deepmd/backend/backend.py +++ b/deepmd/backend/backend.py @@ -127,8 +127,9 @@ def detect_backend_by_model(filename: str) -> Type["Backend"]: """ filename = str(filename).lower() for backend in Backend.get_backends().values(): - if filename.endswith(backend.suffixes[0]): - return backend + for suffix in backend.suffixes: + if filename.endswith(suffix): + return backend raise ValueError(f"Cannot detect the backend of the model file {filename}.") class Feature(Flag): diff --git a/deepmd/backend/dpmodel.py b/deepmd/backend/dpmodel.py index 2e24f36447..8745ca6d5a 100644 --- a/deepmd/backend/dpmodel.py +++ b/deepmd/backend/dpmodel.py @@ -35,7 +35,7 @@ class DPModelBackend(Backend): """The formal name of the backend.""" features: ClassVar[Backend.Feature] = Backend.Feature.NEIGHBOR_STAT """The features of the backend.""" - suffixes: ClassVar[List[str]] = ["dp"] + suffixes: ClassVar[List[str]] = [".dp"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/backend/pytorch.py b/deepmd/backend/pytorch.py index 146ce7f9f2..4c0b0699f9 100644 --- a/deepmd/backend/pytorch.py +++ b/deepmd/backend/pytorch.py @@ -40,7 +40,7 @@ class TensorFlowBackend(Backend): | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[List[str]] = ["pth", "pt"] + suffixes: ClassVar[List[str]] = [".pth", ".pt"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/backend/tensorflow.py b/deepmd/backend/tensorflow.py index b37b887bec..80569afa97 100644 --- a/deepmd/backend/tensorflow.py +++ b/deepmd/backend/tensorflow.py @@ -40,7 +40,7 @@ class TensorFlowBackend(Backend): | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[List[str]] = ["pb"] + suffixes: ClassVar[List[str]] = [".pb"] """The suffixes of the backend.""" def is_available(self) -> bool: